Better support for resuming training (#8878)

This commit is contained in:
Sylvain Gugger
2020-12-01 13:45:21 -05:00
committed by GitHub
parent 21db560df3
commit 7c10dd22ae
3 changed files with 65 additions and 11 deletions

View File

@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertEqual(b, b1)
self.assertEqual(state, state1)
# Now check with a later checkpoint that it also works when we span over one epoch
checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
self.assertEqual(b, b1)
self.assertEqual(state, state1)
# Now check with a later checkpoint that it also works when we span over one epoch
checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model
model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
model.load_state_dict(state_dict)
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
trainer.train(model_path=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of