Quality (#20002)
This commit is contained in:
@@ -117,6 +117,36 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = nn.Linear(4, 5)
|
||||
self.linear_2 = nn.Linear(5, 6)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(self.linear(x))
|
||||
|
||||
class ModelWithHead(PreTrainedModel):
|
||||
base_model_prefix = "base"
|
||||
config_class = PretrainedConfig
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.base = BaseModel(config)
|
||||
# linear is a common name between Base and Head on purpose.
|
||||
self.linear = nn.Linear(6, 3)
|
||||
self.linear2 = nn.Linear(3, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.linear(self.base(x)))
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_base_model_to_head_model_load(self):
|
||||
base_model = BaseModel(PretrainedConfig())
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
base_model.save_pretrained(tmp_dir)
|
||||
|
||||
# Can load a base model in a model with head
|
||||
model = ModelWithHead.from_pretrained(tmp_dir)
|
||||
for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
# It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
|
||||
base_state_dict = base_model.state_dict()
|
||||
head_state_dict = model.state_dict()
|
||||
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
|
||||
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
|
||||
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The state dictionary of the model you are trying to load is corrupted."
|
||||
):
|
||||
_ = ModelWithHead.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user