Deprecate model_path in Trainer.train (#9854)
This commit is contained in:
@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
learning_rate=0.1,
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
|
||||
Reference in New Issue
Block a user