make failure to find a resume checkpoint fatal + tests (#10777)
This commit is contained in:
@@ -876,7 +876,10 @@ class Trainer:
|
|||||||
if resume_from_checkpoint is None:
|
if resume_from_checkpoint is None:
|
||||||
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
|
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
|
||||||
|
|
||||||
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
if resume_from_checkpoint is not None:
|
||||||
|
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||||
|
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
|
||||||
|
|
||||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||||
|
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
|
|||||||
@@ -613,7 +613,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
return
|
return
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
|
trainer = get_regression_trainer(**kwargs)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state = dataclasses.asdict(trainer.state)
|
state = dataclasses.asdict(trainer.state)
|
||||||
@@ -621,7 +622,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
# Reinitialize trainer
|
# Reinitialize trainer
|
||||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
trainer = get_regression_trainer(**kwargs)
|
||||||
|
|
||||||
trainer.train(resume_from_checkpoint=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -634,7 +635,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
trainer = get_regression_trainer(**kwargs)
|
||||||
|
|
||||||
trainer.train(resume_from_checkpoint=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -645,9 +646,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
# With a regular model that is not a PreTrainedModel
|
# With a regular model that is not a PreTrainedModel
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
trainer = get_regression_trainer(
|
kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False)
|
||||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
|
||||||
)
|
trainer = get_regression_trainer(**kwargs)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state = dataclasses.asdict(trainer.state)
|
state = dataclasses.asdict(trainer.state)
|
||||||
@@ -655,9 +656,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(**kwargs)
|
||||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train(resume_from_checkpoint=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -670,9 +669,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(**kwargs)
|
||||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train(resume_from_checkpoint=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -681,6 +678,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(b, b1)
|
self.assertEqual(b, b1)
|
||||||
self.check_trainer_state_are_the_same(state, state1)
|
self.check_trainer_state_are_the_same(state, state1)
|
||||||
|
|
||||||
|
# Now check failures
|
||||||
|
|
||||||
|
# 1. fail to find a bogus checkpoint
|
||||||
|
trainer = get_regression_trainer()
|
||||||
|
with self.assertRaises(Exception) as context:
|
||||||
|
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
|
||||||
|
self.assertTrue("Can't find a valid checkpoint at" in str(context.exception))
|
||||||
|
|
||||||
|
# 2. fail to find any checkpoint - due a fresh output_dir
|
||||||
|
output_dir2 = self.get_auto_remove_tmp_dir()
|
||||||
|
trainer = get_regression_trainer(output_dir=output_dir2)
|
||||||
|
with self.assertRaises(Exception) as context:
|
||||||
|
trainer.train(resume_from_checkpoint=True)
|
||||||
|
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||||
|
|
||||||
def test_resume_training_with_gradient_accumulation(self):
|
def test_resume_training_with_gradient_accumulation(self):
|
||||||
if torch.cuda.device_count() > 2:
|
if torch.cuda.device_count() > 2:
|
||||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
|||||||
Reference in New Issue
Block a user