Load checkpoint without re-creating the model (#11318)

This commit is contained in:
Sylvain Gugger
2021-04-19 20:31:29 -04:00
committed by GitHub
parent 95037a169f
commit c0328a6c26
3 changed files with 61 additions and 12 deletions

View File

@@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_resume_training_with_frozen_params(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.model.a.requires_grad_(False)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.model.a.requires_grad_(False)
trainer.train(resume_from_checkpoint=checkpoint)
self.assertFalse(trainer.model.a.requires_grad)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir: