From c0328a6c263494fff527fac7288faa627e3267e0 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 19 Apr 2021 20:31:29 -0400 Subject: [PATCH] Load checkpoint without re-creating the model (#11318) --- src/transformers/configuration_utils.py | 2 +- src/transformers/trainer.py | 31 ++++++++++++------- tests/test_trainer.py | 40 +++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2b08d10b24..3aa671251c 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -271,7 +271,7 @@ class PretrainedConfig(object): self._name_or_path = str(kwargs.pop("name_or_path", "")) # Drop the transformers version info - kwargs.pop("transformers_version", None) + self.transformers_version = kwargs.pop("transformers_version", None) # Additional attributes without default values for key, value in kwargs.items(): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a6e6e81e43..a0d4440f2f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -55,9 +55,12 @@ from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler +from . import __version__ +from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .dependency_versions_check import dep_version_check from .file_utils import ( + CONFIG_NAME, WEIGHTS_NAME, is_apex_available, is_datasets_available, @@ -999,14 +1002,23 @@ class Trainer: logger.info(f"Loading model from {resume_from_checkpoint}).") + if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): + config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warn( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + if self.deepspeed: # will be resumed in deepspeed_init pass - elif isinstance(self.model, PreTrainedModel): - self.model = self.model.from_pretrained(resume_from_checkpoint) - model_reloaded = True else: - state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # If the model is on the GPU, it still works! self.model.load_state_dict(state_dict) # If model was re-initialized, put it on the right device and update self.model_wrapped @@ -1293,13 +1305,10 @@ class Trainer: logger.info( f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." ) - if isinstance(self.model, PreTrainedModel): - self.model = self.model.from_pretrained(self.state.best_model_checkpoint) - if self.place_model_on_device: - self.model = self.model.to(args.device) - else: - state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) - self.model.load_state_dict(state_dict) + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu") + # If the model is on the GPU, it still works! + self.model.load_state_dict(state_dict) if self.deepspeed: self.deepspeed.load_checkpoint( diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f3ebf14a87..b5071783f2 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(b, b1) self.check_trainer_state_are_the_same(state, state1) + def test_resume_training_with_frozen_params(self): + if torch.cuda.device_count() > 2: + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). + return + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=128, + per_device_train_batch_size=4, + save_steps=5, + learning_rate=0.1, + ) + trainer.model.a.requires_grad_(False) + trainer.train() + (a, b) = trainer.model.a.item(), trainer.model.b.item() + state = dataclasses.asdict(trainer.state) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + + # Reinitialize trainer + trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=128, + per_device_train_batch_size=4, + save_steps=5, + learning_rate=0.1, + ) + trainer.model.a.requires_grad_(False) + + trainer.train(resume_from_checkpoint=checkpoint) + + self.assertFalse(trainer.model.a.requires_grad) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.check_trainer_state_are_the_same(state, state1) + def test_load_best_model_at_end(self): total = int(self.n_epochs * 64 / self.batch_size) with tempfile.TemporaryDirectory() as tmpdir: