From 0001d056861bb1ec7bd6a825006f578629a101fc Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Thu, 5 Mar 2020 17:01:54 -0500
Subject: [PATCH] Correct missing keys + test (#3143)

---
Reviewer notes (this section, between the three-dash line and the first
"diff --git" line, is commentary and not part of the commit):

* Line structure restored. This patch had been flattened onto two physical
  lines, with the break falling inside the "Missing keys for {}" literal.
  Hunk indentation is reconstructed from the Python block structure and
  the hunk line counts; confirm against the target files before applying.
* The From: header carries no e-mail address in this copy; it presumably
  has to be restored before the patch can be applied as a mail.
* modeling_utils.py: head_model_state_dict_without_base_prefix is a list
  and base_model_state_dict is a keys() view, so the subtraction relies on
  the view's reflected set difference. The result is presumably an
  unordered set, so the order in which keys are appended to missing_keys
  is not stable. A follow-up could use sorted(set(head) - set(base)).
* modeling_utils.py: key.split(cls.base_model_prefix + ".")[-1] keeps only
  the text after the LAST occurrence of the prefix anywhere in the key,
  not just a leading prefix. TODO confirm no parameter name contains the
  prefix in a non-leading position.

 src/transformers/modeling_utils.py |  9 +++++++++
 tests/test_modeling_common.py      | 15 +++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7dd0e873dc..203d5e8057 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -539,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 model_to_load = getattr(model, cls.base_model_prefix)
 
             load(model_to_load, prefix=start_prefix)
+
+            if model.__class__.__name__ != model_to_load.__class__.__name__:
+                base_model_state_dict = model_to_load.state_dict().keys()
+                head_model_state_dict_without_base_prefix = [
+                    key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
+                ]
+
+                missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
+
             if len(missing_keys) > 0:
                 logger.info(
                     "Weights of {} not initialized from pretrained model: {}".format(
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9ba00d2421..a52d746947 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -526,6 +526,21 @@ class ModelTesterMixin:
             x = model.get_output_embeddings()
             self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
 
+    def test_correct_missing_keys(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            base_model_prefix = model.base_model_prefix
+
+            if hasattr(model, base_model_prefix):
+                with tempfile.TemporaryDirectory() as temp_dir_name:
+                    model.base_model.save_pretrained(temp_dir_name)
+                    model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
+
+                    with self.subTest(msg="Missing keys for {}".format(model.__class__.__name__)):
+                        self.assertGreater(len(loading_info["missing_keys"]), 0)
+
     def test_tie_model_weights(self):
         if not self.test_torchscript:
             return