When resuming training from checkpoint, Trainer loads model (#9818)
* When resuming training from checkpoint, Trainer loads model * Finish cleaning tests * Address review comment * Use global_step from state
This commit is contained in:
@@ -578,9 +578,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -593,8 +592,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -615,10 +613,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
model.load_state_dict(state_dict)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -631,10 +628,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
model.load_state_dict(state_dict)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
@@ -664,9 +660,15 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
gradient_accumulation_steps=2,
|
||||
per_device_train_batch_size=4,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
|
||||
Reference in New Issue
Block a user