Tied weights load (#24310)

* Use tied weight keys

* More

* Fix tied weight missing warning

* Only give info on unexpected keys with different classes

* Deal with empty archs

* Fix tests

* Refine test
This commit is contained in:
Sylvain Gugger
2023-06-16 10:55:42 -04:00
committed by GitHub
parent 61ffdeba38
commit 096f2cf126
2 changed files with 96 additions and 20 deletions

View File

@@ -82,16 +82,31 @@ if is_torch_available():
# Fake pretrained models for tests
# Small model made of two linear layers applied in sequence (`linear`, then `linear_2`).
# NOTE(review): this text is a flattened diff view, so indentation and +/- markers are
# missing. The two pairs of `linear` / `linear_2` assignments in `__init__` look like the
# removed (4->5, 5->6) and added (5->5, 5->5) sides of one hunk -- confirm against the real file.
class BaseModel(PreTrainedModel):
# Prefix under which a wrapping model stores this one (see `self.base` in the head models below).
base_model_prefix = "base"
config_class = PretrainedConfig
def __init__(self, config):
super().__init__(config)
# Presumably the pre-change layer shapes -- TODO confirm.
self.linear = nn.Linear(4, 5)
self.linear_2 = nn.Linear(5, 6)
# Presumably the post-change layer shapes: every layer is now 5 -> 5 -- TODO confirm.
self.linear = nn.Linear(5, 5)
self.linear_2 = nn.Linear(5, 5)
def forward(self, x):
# Feed `x` through `linear`, then through `linear_2`.
return self.linear_2(self.linear(x))
class BaseModelWithTiedWeights(PreTrainedModel):
    """Two stacked 5x5 linear layers whose weight matrices can be tied together."""

    config_class = PretrainedConfig

    def __init__(self, config):
        """Build the two linear layers (`linear` first, then `linear_2`)."""
        super().__init__(config)
        self.linear = nn.Linear(5, 5)
        self.linear_2 = nn.Linear(5, 5)

    def forward(self, x):
        """Run `x` through `linear` and feed the result to `linear_2`."""
        hidden = self.linear(x)
        return self.linear_2(hidden)

    def tie_weights(self):
        """Make `linear_2` reuse the very same weight parameter object as `linear`."""
        shared_weight = self.linear.weight
        self.linear_2.weight = shared_weight
class ModelWithHead(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
@@ -103,12 +118,30 @@ if is_torch_available():
super().__init__(config)
self.base = BaseModel(config)
# linear is a common name between Base and Head on purpose.
self.linear = nn.Linear(6, 3)
self.linear2 = nn.Linear(3, 5)
self.linear = nn.Linear(5, 5)
self.linear2 = nn.Linear(5, 5)
def forward(self, x):
return self.linear2(self.linear(self.base(x)))
class ModelWithHeadAndTiedWeights(PreTrainedModel):
    """A `BaseModel` followed by a 5x5 `decoder` layer whose weight can be tied to `base.linear`."""

    base_model_prefix = "base"
    config_class = PretrainedConfig

    def __init__(self, config):
        """Create the wrapped base model, then the decoder head."""
        super().__init__(config)
        self.base = BaseModel(config)
        self.decoder = nn.Linear(5, 5)

    def _init_weights(self, module):
        """Intentionally a no-op: this hook does nothing to `module`."""
        return None

    def forward(self, x):
        """Apply the base model to `x`, then the decoder to its output."""
        base_output = self.base(x)
        return self.decoder(base_output)

    def tie_weights(self):
        """Point the decoder weight at the same parameter object as `base.linear`."""
        shared_weight = self.base.linear.weight
        self.decoder.weight = shared_weight
# String identifiers for two small checkpoints (presumably Hugging Face Hub repo ids
# of tiny test models -- confirm against where they are used).
TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@@ -857,6 +890,54 @@ class ModelUtilsTest(TestCasePlus):
):
_ = ModelWithHead.from_pretrained(tmp_dir)
def test_tied_weights_reload(self):
    """Reloading a checkpoint keeps tied weights tied and does not report them as missing.

    Three scenarios are exercised against the same temporary directory:
    a plain save/reload, a reload from a state dict with the tied key removed,
    and a reload of the base checkpoint into a model with a tied head.
    """
    source_model = BaseModelWithTiedWeights(PretrainedConfig())
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Scenario 1: plain round trip through save_pretrained / from_pretrained.
        source_model.save_pretrained(tmp_dir)
        reloaded = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
        self.assertIs(reloaded.linear.weight, reloaded.linear_2.weight)

        # Scenario 2: overwrite the weights file with a state dict that lacks the tied
        # entry; loading must not list it as a missing key.
        weights = source_model.state_dict()
        weights.pop("linear_2.weight")
        weights_path = os.path.join(tmp_dir, WEIGHTS_NAME)
        torch.save(weights, weights_path)
        reloaded, info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertListEqual(info["missing_keys"], [])
        self.assertIs(reloaded.linear.weight, reloaded.linear_2.weight)

        # Scenario 3: save the base model again and load it into the model with a head.
        source_model.save_pretrained(tmp_dir)
        reloaded, info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertIs(reloaded.base.linear.weight, reloaded.decoder.weight)
        # The decoder weight is tied, so only its bias may be reported as missing.
        self.assertListEqual(info["missing_keys"], ["decoder.bias"])
def test_unexpected_keys_warnings(self):
    """Check how `from_pretrained` reports checkpoint keys the target model does not use.

    A `ModelWithHead` checkpoint is loaded twice:
    * with a different class (`BaseModel`): the unexpected head keys are returned in the
      loading info, but no "were not used" message must appear in the captured output;
    * with the same class, after adding a bogus key to the weights file: the
      "were not used" message must be captured and the bogus key reported.
    """
    model = ModelWithHead(PretrainedConfig())
    logger = logging.get_logger("transformers.modeling_utils")
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)

        # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
        with CaptureLogger(logger) as cl:
            _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
        # The message names the class being initialized (see the same-class assertion
        # further down), so the class to look for here is `BaseModel`. Searching for
        # "ModelWithHead" could never match and made this check vacuous.
        # NOTE(review): this presumably relies on the logger verbosity being above INFO
        # while the test runs -- confirm.
        self.assertNotIn("were not used when initializing BaseModel", cl.out)
        self.assertEqual(
            set(loading_info["unexpected_keys"]),
            {"linear.weight", "linear.bias", "linear2.weight", "linear2.bias"},
        )

        # Loading the model with the same class, we do get a warning for unexpected weights
        state_dict = model.state_dict()
        # Add a key that no module of ModelWithHead owns.
        state_dict["added_key"] = state_dict["linear.weight"]
        torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
        with CaptureLogger(logger) as cl:
            _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
        self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
        self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
@require_torch_gpu
@slow
def test_pretrained_low_mem_new_config(self):