From 49b77b89ea1e89a9940f2b84da1bcc0696ecb07a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 2 Nov 2022 09:53:37 -0400 Subject: [PATCH] Quality (#20002) --- src/transformers/modeling_utils.py | 3 +- tests/test_modeling_common.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 152e52fe25..1c28d00ebd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix start_prefix = cls.base_model_prefix + "." if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: model_to_load = getattr(model, cls.base_model_prefix) - if any(key in expected_keys_not_prefixed for key in loaded_keys): + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): raise ValueError( "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " "properly saved?" 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 5d299f5c83..6a0d3b7dc9 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -117,6 +117,36 @@ if is_torch_available():
     )
     from transformers.modeling_utils import shard_checkpoint
 
+    class BaseModel(PreTrainedModel):
+        """Headless toy model that chains two linear layers (4 -> 5 -> 6 features)."""
+        config_class = PretrainedConfig
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.linear = nn.Linear(4, 5)
+            self.linear_2 = nn.Linear(5, 6)
+
+        def forward(self, x):
+            hidden = self.linear(x)
+            return self.linear_2(hidden)
+
+    class ModelWithHead(PreTrainedModel):
+        """`BaseModel` kept under `base` plus two head layers; `linear` reuses a base layer name on purpose."""
+        config_class = PretrainedConfig
+        base_model_prefix = "base"
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.base = BaseModel(config)
+            self.linear = nn.Linear(6, 3)
+            self.linear2 = nn.Linear(3, 5)
+
+        def _init_weights(self, module):
+            """Deliberately a no-op: no module is re-initialized."""
+
+        def forward(self, x):
+            return self.linear2(self.linear(self.base(x)))
+
 
 if is_tf_available():
     import tensorflow as tf
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
             self.assertTrue(torch.allclose(p1, p2))
 
+    def test_base_model_to_head_model_load(self):
+        """A bare base checkpoint loads into a head model; one mixed with unprefixed head keys is rejected."""
+        corrupted_msg = "The state dictionary of the model you are trying to load is corrupted."
+        headless = BaseModel(PretrainedConfig())
+        with tempfile.TemporaryDirectory() as checkpoint_dir:
+            headless.save_pretrained(checkpoint_dir)
+
+            # A checkpoint saved from the base model alone must populate the `base` submodule of the head model.
+            with_head = ModelWithHead.from_pretrained(checkpoint_dir)
+            for loaded, reference in zip(with_head.base.parameters(), headless.parameters()):
+                self.assertTrue(torch.allclose(loaded, reference))
+
+            # Mixing unprefixed head keys into the base state dict must be reported as a corrupted checkpoint.
+            mixed_state_dict = headless.state_dict()
+            head_state_dict = with_head.state_dict()
+            for key in ("linear2.weight", "linear2.bias"):
+                mixed_state_dict[key] = head_state_dict[key]
+            torch.save(mixed_state_dict, os.path.join(checkpoint_dir, WEIGHTS_NAME))
+
+            with self.assertRaisesRegex(ValueError, corrupted_msg):
+                ModelWithHead.from_pretrained(checkpoint_dir)
+
 
 @require_torch
 @is_staging_test