diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a2c56f8a8b..84abde3e3f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1005,6 +1005,7 @@ class Trainer: kwargs: Additional keyword arguments used to hide deprecated arguments """ + resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint # memory metrics - must set up as early as possible self._memory_tracker.start() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index c94b61751c..97ad249e6f 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -827,6 +827,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertAlmostEqual(a, a1, delta=1e-8) self.assertAlmostEqual(b, b1, delta=1e-8) + # regression for this issue: https://github.com/huggingface/transformers/issues/12970 + def test_training_with_resume_from_checkpoint_flase(self): + train_dataset = RegressionDataset(length=128) + eval_dataset = RegressionDataset() + + config = RegressionModelConfig(a=0, b=2) + model = RegressionRandomPreTrainedModel(config) + + tmp_dir = self.get_auto_remove_tmp_dir() + args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1) + trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + + trainer.train(resume_from_checkpoint=False) + @require_torch_up_to_2_gpus def test_resume_training_with_gradient_accumulation(self): # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of