Deprecate model_path in Trainer.train (#9854)

This commit is contained in:
Sylvain Gugger
2021-01-28 08:32:46 -05:00
committed by GitHub
parent 2ee9f9b69e
commit b4e559cfa1
14 changed files with 96 additions and 78 deletions

View File

@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# Reinitialize trainer
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train(model_path=checkpoint)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# Reinitialize trainer and load model
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train(model_path=checkpoint)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer.train(model_path=checkpoint)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
)
trainer.train(model_path=checkpoint)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
learning_rate=0.1,
)
trainer.train(model_path=checkpoint)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)