Clean the Trainer state (#7490)
* Trainer should not modify its TrainingArguments
* Trainer should not modify its TrainingArguments
* Trainer should not modify its TrainingArguments
* Add test of resumed training
* Fixes
* Non multiGPU test
* Clean Trainer state
* Add more to the state
* Documentation
* One last test
* Make resume training test more complete
* Unwanted changes
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import json
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -22,6 +22,7 @@ if is_torch_available():
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
|
||||
|
||||
@@ -155,7 +156,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(model.b, b))
|
||||
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
if is_pretrained:
|
||||
file_list.append("config.json")
|
||||
for step in range(freq, total, freq):
|
||||
@@ -168,7 +169,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
||||
):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
|
||||
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
||||
|
||||
values = [d[metric] for d in log_history]
|
||||
best_value = max(values) if greater_is_better else min(values)
|
||||
@@ -188,6 +189,12 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
metrics = trainer.evaluate()
|
||||
self.assertEqual(metrics[metric], best_value)
|
||||
|
||||
def test_training_arguments_are_left_untouched(self):
|
||||
trainer = get_regression_trainer()
|
||||
trainer.train()
|
||||
args = TrainingArguments("./regression")
|
||||
self.assertEqual(args.to_dict(), trainer.args.to_dict())
|
||||
|
||||
def test_reproducible_training(self):
|
||||
# Checks that training worked, model trained and seed made a reproducible training.
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
@@ -368,6 +375,55 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||
|
||||
def test_can_resume_training(self):
|
||||
if torch.cuda.device_count() > 2:
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
# won't be the same since the training dataloader is shuffled).
|
||||
return
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
model.load_state_dict(state_dict)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
def test_load_best_model_at_end(self):
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
Reference in New Issue
Block a user