From d1fcc90abf34cc498c8a65a717ad0d9354ceca97 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 24 Feb 2022 11:43:51 +0100 Subject: [PATCH] Fix from_pretrained with default base_model_prefix (#15814) --- src/transformers/modeling_utils.py | 12 ++++++++---- tests/test_modeling_common.py | 5 ++++- utils/test_module/custom_modeling.py | 2 -- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 10a313065b..680bc695bd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1580,8 +1580,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix loaded_keys = list(state_dict.keys()) prefix = model.base_model_prefix - has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) - expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False # key re-naming operations are never done on the keys # that are loaded, but always on the keys of the newly initialized model @@ -1669,9 +1673,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" model_to_load = model - if not hasattr(model, cls.base_model_prefix) and has_prefix_module: + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: start_prefix = cls.base_model_prefix + "." - if hasattr(model, cls.base_model_prefix) and not has_prefix_module: + if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: model_to_load = getattr(model, cls.base_model_prefix) if any(key in expected_keys_not_prefixed for key in loaded_keys): raise ValueError( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 17888bcfac..b6ec0eae87 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus): with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) - model = NoSuperInitModel.from_pretrained(tmp_dir) + new_model = NoSuperInitModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) @require_torch diff --git a/utils/test_module/custom_modeling.py b/utils/test_module/custom_modeling.py index 07c078494d..4b64b4a3df 100644 --- a/utils/test_module/custom_modeling.py +++ b/utils/test_module/custom_modeling.py @@ -7,7 +7,6 @@ from .custom_configuration import CustomConfig, NoSuperInitConfig class CustomModel(PreTrainedModel): config_class = CustomConfig - base_model_prefix = "custom" def __init__(self, config): super().__init__(config) @@ -22,7 +21,6 @@ class CustomModel(PreTrainedModel): class NoSuperInitModel(PreTrainedModel): config_class = NoSuperInitConfig - base_model_prefix = "custom" def __init__(self, config): super().__init__(config)