make failure to find a resume checkpoint fatal + tests (#10777)

This commit is contained in:
Stas Bekman
2021-03-17 11:16:37 -07:00
committed by GitHub
parent cd8c93f701
commit 3318c246f3
2 changed files with 28 additions and 13 deletions

View File

@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = get_regression_trainer(**kwargs)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False)
trainer = get_regression_trainer(**kwargs)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer and load model
trainer = get_regression_trainer(
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
checkpoint = os.path.join(tmpdir, "checkpoint-15")
# Reinitialize trainer and load model
trainer = get_regression_trainer(
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
# Now check failures
# 1. fail to find a bogus checkpoint
trainer = get_regression_trainer()
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
self.assertTrue("Can't find a valid checkpoint at" in str(context.exception))
# 2. fail to find any checkpoint - due a fresh output_dir
output_dir2 = self.get_auto_remove_tmp_dir()
trainer = get_regression_trainer(output_dir=output_dir2)
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of