More robust tied weight test (#39681)

* Update test_modeling_common.py

* remove old ones

* Update test_modeling_common.py

* Update test_modeling_common.py

* add

* Update test_modeling_musicgen_melody.py
This commit is contained in:
Cyril Vallez
2025-07-25 22:03:21 +02:00
committed by GitHub
parent c3401d6fad
commit 18a7c29ff8
9 changed files with 15 additions and 71 deletions

View File

@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
extra_params.pop(key, None)
if not extra_params:
# In that case, we *are* on a head model, but every
# single key is not actual parameters and this is
# tested in `test_tied_model_weights_key_ignore` test.
# In that case, we *are* on a head model, but every single key is not actual parameters
continue
with tempfile.TemporaryDirectory() as temp_dir_name:
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(copy.deepcopy(config))
copied_config = copy.deepcopy(original_config)
copied_config.get_text_config().tie_word_embeddings = True
model_tied = model_class(copied_config)
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
# does not tie them
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
continue
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)