From 218d552f306fefc34b60fee8135b976c3ab5807f Mon Sep 17 00:00:00 2001 From: Volodymyr Byno Date: Thu, 13 May 2021 23:11:12 +0300 Subject: [PATCH] Fix loading the best model on the last stage of training (#11718) --- src/transformers/trainer.py | 26 +++++++++++++------------- tests/test_modeling_common.py | 3 ++- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8d79fe14ec..606a137a3e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1059,18 +1059,7 @@ class Trainer: # We load the model state dict on the CPU to avoid an OOM error. state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") # If the model is on the GPU, it still works! - load_result = self.model.load_state_dict(state_dict, strict=False) - if len(load_result.missing_keys) != 0: - if load_result.missing_keys == self.model._keys_to_ignore_on_save: - self.model.tie_weights() - else: - logger.warn( - f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}." - ) - if len(load_result.unexpected_keys) != 0: - logger.warn( - f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." - ) + self._load_state_dict_in_model(state_dict) # If model was re-initialized, put it on the right device and update self.model_wrapped if model_reloaded: @@ -1363,7 +1352,7 @@ class Trainer: # We load the model state dict on the CPU to avoid an OOM error. state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu") # If the model is on the GPU, it still works! - self.model.load_state_dict(state_dict) + self._load_state_dict_in_model(state_dict) if self.deepspeed: self.deepspeed.load_checkpoint( @@ -1385,6 +1374,17 @@ class Trainer: return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) + def _load_state_dict_in_model(self, state_dict): + load_result = self.model.load_state_dict(state_dict, strict=False) + + if len(load_result.missing_keys) != 0: + if set(load_result.missing_keys) == set(self.model._keys_to_ignore_on_save): + self.model.tie_weights() + else: + logger.warn(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warn(f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.") + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): if self.control.should_log: logs: Dict[str, float] = {} diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 00b8080ff9..3ff21b1d5a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -180,7 +180,8 @@ class ModelTesterMixin: # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer. load_result = model.load_state_dict(state_dict_saved, strict=False) self.assertTrue( - len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save + len(load_result.missing_keys) == 0 + or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save) ) self.assertTrue(len(load_result.unexpected_keys) == 0)