Fix from_pretrained with default base_model_prefix (#15814)

This commit is contained in:
Sylvain Gugger
2022-02-24 11:43:51 +01:00
committed by GitHub
parent 7f921bcf47
commit d1fcc90abf
3 changed files with 12 additions and 7 deletions

View File

@@ -1580,8 +1580,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loaded_keys = list(state_dict.keys()) loaded_keys = list(state_dict.keys())
prefix = model.base_model_prefix prefix = model.base_model_prefix
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
# key re-naming operations are never done on the keys # key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model # that are loaded, but always on the keys of the newly initialized model
@@ -1669,9 +1673,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Make sure we are able to load base models as well as derived models (with heads) # Make sure we are able to load base models as well as derived models (with heads)
start_prefix = "" start_prefix = ""
model_to_load = model model_to_load = model
if not hasattr(model, cls.base_model_prefix) and has_prefix_module: if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "." start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module: if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix) model_to_load = getattr(model, cls.base_model_prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys): if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError( raise ValueError(

View File

@@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
model = NoSuperInitModel.from_pretrained(tmp_dir) new_model = NoSuperInitModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch @require_torch

View File

@@ -7,7 +7,6 @@ from .custom_configuration import CustomConfig, NoSuperInitConfig
class CustomModel(PreTrainedModel): class CustomModel(PreTrainedModel):
config_class = CustomConfig config_class = CustomConfig
base_model_prefix = "custom"
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -22,7 +21,6 @@ class CustomModel(PreTrainedModel):
class NoSuperInitModel(PreTrainedModel): class NoSuperInitModel(PreTrainedModel):
config_class = NoSuperInitConfig config_class = NoSuperInitConfig
base_model_prefix = "custom"
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)