More robust tied weight test (#39681)

* Update test_modeling_common.py

* remove old ones

* Update test_modeling_common.py

* Update test_modeling_common.py

* add

* Update test_modeling_musicgen_melody.py
This commit is contained in:
Cyril Vallez
2025-07-25 22:03:21 +02:00
committed by GitHub
parent c3401d6fad
commit 18a7c29ff8
9 changed files with 15 additions and 71 deletions

View File

@@ -14,9 +14,7 @@
# limitations under the License.
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
import collections
import copy
import re
import unittest
import pytest
@@ -52,8 +50,6 @@ if is_datasets_available():
if is_torch_available():
import torch
from transformers.pytorch_utils import id_tensor_storage
class CsmModelTester:
def __init__(
@@ -344,38 +340,9 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
def test_model_parallel_beam_search(self):
pass
@unittest.skip(reason="CSM has special embeddings that can never be tied")
def test_tied_weights_keys(self):
"""
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
pass
def _get_custom_4d_mask_test_data(self):
"""

View File

@@ -108,10 +108,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
model = DbrxModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="Dbrx models have weight tying disabled.")
def test_tied_weights_keys(self):
pass
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
@unittest.skip(reason="Dbrx models do not work with offload")

View File

@@ -309,10 +309,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip(reason="Mamba 2 weights are not tied")
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass

View File

@@ -781,11 +781,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_tie_model_weights(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
def test_tied_weights_keys(self):
pass

View File

@@ -782,11 +782,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_tie_model_weights(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
def test_tied_weights_keys(self):
pass

View File

@@ -656,10 +656,6 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@unittest.skip(reason="Pix2Struct doesn't use tied weights")
def test_tied_model_weights_key_ignore(self):
pass
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")

View File

@@ -176,10 +176,6 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_tie_model_weights(self):
pass
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone")
def test_load_save_without_tied_weights(self):
pass

View File

@@ -184,10 +184,6 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
@unittest.skip(reason="xLSTM has no tied weights")
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
def test_generate_without_input_ids(self):
pass

View File

@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
extra_params.pop(key, None)
if not extra_params:
# In that case, we *are* on a head model, but every
# single key is not actual parameters and this is
# tested in `test_tied_model_weights_key_ignore` test.
# In that case, we *are* on a head model, but every single key is not actual parameters
continue
with tempfile.TemporaryDirectory() as temp_dir_name:
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(copy.deepcopy(config))
copied_config = copy.deepcopy(original_config)
copied_config.get_text_config().tie_word_embeddings = True
model_tied = model_class(copied_config)
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
# does not tie them
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
continue
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)