Fix from_pretrained with default base_model_prefix (#15814)
This commit is contained in:
@@ -1580,8 +1580,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
loaded_keys = list(state_dict.keys())
|
loaded_keys = list(state_dict.keys())
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
|
if len(prefix) > 0:
|
||||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||||
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
||||||
|
else:
|
||||||
|
has_prefix_module = False
|
||||||
|
expects_prefix_module = False
|
||||||
|
|
||||||
# key re-naming operations are never done on the keys
|
# key re-naming operations are never done on the keys
|
||||||
# that are loaded, but always on the keys of the newly initialized model
|
# that are loaded, but always on the keys of the newly initialized model
|
||||||
@@ -1669,9 +1673,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Make sure we are able to load base models as well as derived models (with heads)
|
# Make sure we are able to load base models as well as derived models (with heads)
|
||||||
start_prefix = ""
|
start_prefix = ""
|
||||||
model_to_load = model
|
model_to_load = model
|
||||||
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
||||||
start_prefix = cls.base_model_prefix + "."
|
start_prefix = cls.base_model_prefix + "."
|
||||||
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||||
model_to_load = getattr(model, cls.base_model_prefix)
|
model_to_load = getattr(model, cls.base_model_prefix)
|
||||||
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model.save_pretrained(tmp_dir)
|
model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
model = NoSuperInitModel.from_pretrained(tmp_dir)
|
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from .custom_configuration import CustomConfig, NoSuperInitConfig
|
|||||||
|
|
||||||
class CustomModel(PreTrainedModel):
|
class CustomModel(PreTrainedModel):
|
||||||
config_class = CustomConfig
|
config_class = CustomConfig
|
||||||
base_model_prefix = "custom"
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -22,7 +21,6 @@ class CustomModel(PreTrainedModel):
|
|||||||
|
|
||||||
class NoSuperInitModel(PreTrainedModel):
|
class NoSuperInitModel(PreTrainedModel):
|
||||||
config_class = NoSuperInitConfig
|
config_class = NoSuperInitConfig
|
||||||
base_model_prefix = "custom"
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user