When resuming training from checkpoint, Trainer loads model (#9818)
* When resuming training from checkpoint, Trainer loads model * Finish cleaning tests * Address review comment * Use global_step from state
This commit is contained in:
@@ -688,20 +688,31 @@ class Trainer:
|
|||||||
self._hp_search_setup(trial)
|
self._hp_search_setup(trial)
|
||||||
|
|
||||||
# Model re-init
|
# Model re-init
|
||||||
|
model_reloaded = False
|
||||||
if self.model_init is not None:
|
if self.model_init is not None:
|
||||||
# Seed must be set before instantiating the model when using model_init.
|
# Seed must be set before instantiating the model when using model_init.
|
||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
|
self.model = self.call_model_init(trial)
|
||||||
model = self.call_model_init(trial)
|
model_reloaded = True
|
||||||
if not self.is_model_parallel:
|
|
||||||
model = model.to(self.args.device)
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.model_wrapped = model
|
|
||||||
|
|
||||||
# Reinitializes optimizer and scheduler
|
# Reinitializes optimizer and scheduler
|
||||||
self.optimizer, self.lr_scheduler = None, None
|
self.optimizer, self.lr_scheduler = None, None
|
||||||
|
|
||||||
|
# Load potential model checkpoint
|
||||||
|
if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
|
||||||
|
logger.info(f"Loading model from {model_path}).")
|
||||||
|
if isinstance(self.model, PreTrainedModel):
|
||||||
|
self.model = self.model.from_pretrained(model_path)
|
||||||
|
model_reloaded = True
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
|
if model_reloaded:
|
||||||
|
if not self.is_model_parallel:
|
||||||
|
self.model = self.model.to(self.args.device)
|
||||||
|
self.model_wrapped = self.model
|
||||||
|
|
||||||
# Keeping track whether we can can len() on the dataset or not
|
# Keeping track whether we can can len() on the dataset or not
|
||||||
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
|
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
|
||||||
|
|
||||||
@@ -849,7 +860,7 @@ class Trainer:
|
|||||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||||
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
||||||
self._total_loss_scalar = 0.0
|
self._total_loss_scalar = 0.0
|
||||||
self._globalstep_last_logged = 0
|
self._globalstep_last_logged = self.state.global_step
|
||||||
self._total_flos = self.state.total_flos
|
self._total_flos = self.state.total_flos
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
|
|
||||||
|
|||||||
@@ -578,9 +578,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer
|
||||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(model_path=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -593,8 +592,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(model_path=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -615,10 +613,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
model = RegressionModel()
|
trainer = get_regression_trainer(
|
||||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||||
model.load_state_dict(state_dict)
|
)
|
||||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(model_path=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -631,10 +628,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
checkpoint = os.path.join(tmpdir, "checkpoint-15")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
model = RegressionModel()
|
trainer = get_regression_trainer(
|
||||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||||
model.load_state_dict(state_dict)
|
)
|
||||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(model_path=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
@@ -664,9 +660,15 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer
|
||||||
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
trainer = get_regression_trainer(
|
||||||
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
output_dir=tmpdir,
|
||||||
|
train_len=128,
|
||||||
|
gradient_accumulation_steps=2,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
save_steps=5,
|
||||||
|
learning_rate=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(model_path=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
|||||||
Reference in New Issue
Block a user