Better support for resuming training (#8878)
This commit is contained in:
@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
# Now check with a later checkpoint that it also works when we span over one epoch
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
# Now check with a later checkpoint that it also works when we span over one epoch
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
model.load_state_dict(state_dict)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.assertEqual(state, state1)
|
||||
|
||||
def test_resume_training_with_gradient_accumulation(self):
|
||||
if torch.cuda.device_count() > 2:
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
|
||||
Reference in New Issue
Block a user