This commit is contained in:
Sylvain Gugger
2022-11-02 09:53:37 -04:00
committed by GitHub
parent c6c9db3d0c
commit 49b77b89ea
2 changed files with 54 additions and 1 deletions

View File

@@ -117,6 +117,36 @@ if is_torch_available():
)
from transformers.modeling_utils import shard_checkpoint
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
config_class = PretrainedConfig
def __init__(self, config):
super().__init__(config)
self.linear = nn.Linear(4, 5)
self.linear_2 = nn.Linear(5, 6)
def forward(self, x):
return self.linear_2(self.linear(x))
class ModelWithHead(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
def _init_weights(self, module):
pass
def __init__(self, config):
super().__init__(config)
self.base = BaseModel(config)
# linear is a common name between Base and Head on purpose.
self.linear = nn.Linear(6, 3)
self.linear2 = nn.Linear(3, 5)
def forward(self, x):
return self.linear2(self.linear(self.base(x)))
if is_tf_available():
import tensorflow as tf
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_base_model_to_head_model_load(self):
base_model = BaseModel(PretrainedConfig())
with tempfile.TemporaryDirectory() as tmp_dir:
base_model.save_pretrained(tmp_dir)
# Can load a base model in a model with head
model = ModelWithHead.from_pretrained(tmp_dir)
for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
# It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
base_state_dict = base_model.state_dict()
head_state_dict = model.state_dict()
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
with self.assertRaisesRegex(
ValueError, "The state dictionary of the model you are trying to load is corrupted."
):
_ = ModelWithHead.from_pretrained(tmp_dir)
@require_torch
@is_staging_test