make failure to find a resume checkpoint fatal + tests (#10777)
This commit is contained in:
@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False)
|
||||
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Now check failures
|
||||
|
||||
# 1. fail to find a bogus checkpoint
|
||||
trainer = get_regression_trainer()
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
|
||||
self.assertTrue("Can't find a valid checkpoint at" in str(context.exception))
|
||||
|
||||
# 2. fail to find any checkpoint - due a fresh output_dir
|
||||
output_dir2 = self.get_auto_remove_tmp_dir()
|
||||
trainer = get_regression_trainer(output_dir=output_dir2)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||
|
||||
def test_resume_training_with_gradient_accumulation(self):
|
||||
if torch.cuda.device_count() > 2:
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
|
||||
Reference in New Issue
Block a user