fix Trainer.train(resume_from_checkpoint=False) is causing an exception (#12981)
* fix #12970 * Update tests/test_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove unnecessary issue link * fix test formatting Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1005,6 +1005,7 @@ class Trainer:
|
|||||||
kwargs:
|
kwargs:
|
||||||
Additional keyword arguments used to hide deprecated arguments
|
Additional keyword arguments used to hide deprecated arguments
|
||||||
"""
|
"""
|
||||||
|
resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
|
|||||||
@@ -827,6 +827,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertAlmostEqual(a, a1, delta=1e-8)
|
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||||
|
|
||||||
|
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||||
|
def test_training_with_resume_from_checkpoint_flase(self):
|
||||||
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
eval_dataset = RegressionDataset()
|
||||||
|
|
||||||
|
config = RegressionModelConfig(a=0, b=2)
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=False)
|
||||||
|
|
||||||
@require_torch_up_to_2_gpus
|
@require_torch_up_to_2_gpus
|
||||||
def test_resume_training_with_gradient_accumulation(self):
|
def test_resume_training_with_gradient_accumulation(self):
|
||||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
|||||||
Reference in New Issue
Block a user