Quality (#20002)
This commit is contained in:
@@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
start_prefix = cls.base_model_prefix + "."
|
start_prefix = cls.base_model_prefix + "."
|
||||||
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||||
model_to_load = getattr(model, cls.base_model_prefix)
|
model_to_load = getattr(model, cls.base_model_prefix)
|
||||||
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
base_model_expected_keys = list(model_to_load.state_dict().keys())
|
||||||
|
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
||||||
"properly saved?"
|
"properly saved?"
|
||||||
|
|||||||
@@ -117,6 +117,36 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from transformers.modeling_utils import shard_checkpoint
|
from transformers.modeling_utils import shard_checkpoint
|
||||||
|
|
||||||
|
# Fake pretrained models for tests
|
||||||
|
class BaseModel(PreTrainedModel):
    """Minimal fake pretrained model used as the "base" in loading tests.

    It chains two linear layers (4 -> 5 features, then 5 -> 6 features).
    The attribute names `linear` and `linear_2` determine the keys of the
    model's state dict, so they must not be renamed.
    """

    config_class = PretrainedConfig

    def __init__(self, config):
        """Build the two linear layers; `config` is forwarded to the parent."""
        super().__init__(config)
        # Creation order is kept as-is: first layer, then second layer.
        self.linear = nn.Linear(4, 5)
        self.linear_2 = nn.Linear(5, 6)

    def forward(self, x):
        """Apply `linear` and then `linear_2` to `x` and return the result."""
        hidden = self.linear(x)
        return self.linear_2(hidden)
|
||||||
|
|
||||||
|
class ModelWithHead(PreTrainedModel):
    """Fake pretrained model made of a `BaseModel` plus a two-layer head.

    The base is stored under the attribute `base`, which matches
    `base_model_prefix`. The head maps 6 -> 3 -> 5 features.
    """

    base_model_prefix = "base"
    config_class = PretrainedConfig

    def __init__(self, config):
        """Create the wrapped base model and the head layers."""
        super().__init__(config)
        self.base = BaseModel(config)
        # linear is a common name between Base and Head on purpose.
        self.linear = nn.Linear(6, 3)
        self.linear2 = nn.Linear(3, 5)

    def _init_weights(self, module):
        """Intentionally do nothing for `module`."""

    def forward(self, x):
        """Run `x` through the base model, then `linear`, then `linear2`."""
        base_output = self.base(x)
        head_hidden = self.linear(base_output)
        return self.linear2(head_hidden)
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_base_model_to_head_model_load(self):
    """Check loading of a base-model checkpoint into a model with a head.

    First, a checkpoint saved from `BaseModel` must load into
    `ModelWithHead` so that the parameters of `model.base` match the saved
    ones. Second, a checkpoint that mixes un-prefixed base keys with head
    keys must be rejected with a `ValueError`.
    """
    source_model = BaseModel(PretrainedConfig())
    with tempfile.TemporaryDirectory() as tmp_dir:
        source_model.save_pretrained(tmp_dir)

        # Can load a base model in a model with head
        model = ModelWithHead.from_pretrained(tmp_dir)
        parameter_pairs = zip(model.base.parameters(), source_model.parameters())
        for loaded_param, saved_param in parameter_pairs:
            self.assertTrue(torch.allclose(loaded_param, saved_param))

        # It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
        mixed_state_dict = source_model.state_dict()
        head_state_dict = model.state_dict()
        for head_key in ("linear2.weight", "linear2.bias"):
            mixed_state_dict[head_key] = head_state_dict[head_key]
        weights_path = os.path.join(tmp_dir, WEIGHTS_NAME)
        torch.save(mixed_state_dict, weights_path)

        expected_message = "The state dictionary of the model you are trying to load is corrupted."
        with self.assertRaisesRegex(ValueError, expected_message):
            ModelWithHead.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user