From 35d55b7b8484c0baf9448cf57c6f4e987b306fb5 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 27 Jan 2021 09:31:18 -0500 Subject: [PATCH] When resuming training from checkpoint, Trainer loads model (#9818) * Whenresuming training from checkpoint, Trainer loads model * Finish cleaning tests * Address review comment * Use global_step from state --- src/transformers/trainer.py | 29 ++++++++++++++++++++--------- tests/test_trainer.py | 34 ++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 701657fdb4..7472712c74 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -688,20 +688,31 @@ class Trainer: self._hp_search_setup(trial) # Model re-init + model_reloaded = False if self.model_init is not None: # Seed must be set before instantiating the model when using model_init. set_seed(self.args.seed) - - model = self.call_model_init(trial) - if not self.is_model_parallel: - model = model.to(self.args.device) - - self.model = model - self.model_wrapped = model - + self.model = self.call_model_init(trial) + model_reloaded = True # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None + # Load potential model checkpoint + if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)): + logger.info(f"Loading model from {model_path}).") + if isinstance(self.model, PreTrainedModel): + self.model = self.model.from_pretrained(model_path) + model_reloaded = True + else: + state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME)) + self.model.load_state_dict(state_dict) + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if not self.is_model_parallel: + self.model = self.model.to(self.args.device) + self.model_wrapped = self.model + # Keeping track whether we can can len() on the dataset or not train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) @@ -849,7 +860,7 @@ class Trainer: tr_loss = torch.tensor(0.0).to(self.args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 - self._globalstep_last_logged = 0 + self._globalstep_last_logged = self.state.global_step self._total_flos = self.state.total_flos model.zero_grad() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b30e87c997..9321e30e8a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -578,9 +578,8 @@ class TrainerIntegrationTest(unittest.TestCase): checkpoint = os.path.join(tmpdir, "checkpoint-5") - # Reinitialize trainer and load model - model = RegressionPreTrainedModel.from_pretrained(checkpoint) - trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + # Reinitialize trainer + trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) trainer.train(model_path=checkpoint) (a1, b1) = trainer.model.a.item(), trainer.model.b.item() @@ -593,8 +592,7 @@ class TrainerIntegrationTest(unittest.TestCase): checkpoint = os.path.join(tmpdir, "checkpoint-15") # Reinitialize trainer and load model - model = RegressionPreTrainedModel.from_pretrained(checkpoint) - trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) trainer.train(model_path=checkpoint) (a1, b1) = trainer.model.a.item(), trainer.model.b.item() @@ -615,10 +613,9 @@ class TrainerIntegrationTest(unittest.TestCase): checkpoint = os.path.join(tmpdir, "checkpoint-5") # Reinitialize trainer and load model - model = RegressionModel() - state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) - model.load_state_dict(state_dict) - trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + trainer = get_regression_trainer( + output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False + ) trainer.train(model_path=checkpoint) (a1, b1) = trainer.model.a.item(), trainer.model.b.item() @@ -631,10 +628,9 @@ class TrainerIntegrationTest(unittest.TestCase): checkpoint = os.path.join(tmpdir, "checkpoint-15") # Reinitialize trainer and load model - model = RegressionModel() - state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) - model.load_state_dict(state_dict) - trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + trainer = get_regression_trainer( + output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False + ) trainer.train(model_path=checkpoint) (a1, b1) = trainer.model.a.item(), trainer.model.b.item() @@ -664,9 +660,15 @@ class TrainerIntegrationTest(unittest.TestCase): checkpoint = os.path.join(tmpdir, "checkpoint-5") - # Reinitialize trainer and load model - model = RegressionPreTrainedModel.from_pretrained(checkpoint) - trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + # Reinitialize trainer + trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=128, + gradient_accumulation_steps=2, + per_device_train_batch_size=4, + save_steps=5, + learning_rate=0.1, + ) trainer.train(model_path=checkpoint) (a1, b1) = trainer.model.a.item(), trainer.model.b.item()