More robust tied weight test (#39681)

* Update test_modeling_common.py

* remove old ones

* Update test_modeling_common.py

* Update test_modeling_common.py

* add

* Update test_modeling_musicgen_melody.py
This commit is contained in:
Cyril Vallez
2025-07-25 22:03:21 +02:00
committed by GitHub
parent c3401d6fad
commit 18a7c29ff8
9 changed files with 15 additions and 71 deletions

View File

@@ -14,9 +14,7 @@
# limitations under the License. # limitations under the License.
"""Testing suite for the PyTorch ConversationalSpeechModel model.""" """Testing suite for the PyTorch ConversationalSpeechModel model."""
import collections
import copy import copy
import re
import unittest import unittest
import pytest import pytest
@@ -52,8 +50,6 @@ if is_datasets_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.pytorch_utils import id_tensor_storage
class CsmModelTester: class CsmModelTester:
def __init__( def __init__(
@@ -344,38 +340,9 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
def test_model_parallel_beam_search(self): def test_model_parallel_beam_search(self):
pass pass
@unittest.skip(reason="CSM has special embeddings that can never be tied")
def test_tied_weights_keys(self): def test_tied_weights_keys(self):
""" pass
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
def _get_custom_4d_mask_test_data(self): def _get_custom_4d_mask_test_data(self):
""" """

View File

@@ -108,10 +108,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
model = DbrxModel.from_pretrained(model_name) model = DbrxModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@unittest.skip(reason="Dbrx models have weight tying disabled.")
def test_tied_weights_keys(self):
pass
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts. # Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked) # The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
@unittest.skip(reason="Dbrx models do not work with offload") @unittest.skip(reason="Dbrx models do not work with offload")

View File

@@ -309,10 +309,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
@unittest.skip(reason="Mamba 2 weights are not tied")
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass

View File

@@ -781,11 +781,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_tie_model_weights(self): def test_tie_model_weights(self):
pass pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
def test_tied_weights_keys(self): def test_tied_weights_keys(self):
pass pass

View File

@@ -782,11 +782,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_tie_model_weights(self): def test_tie_model_weights(self):
pass pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
def test_tied_weights_keys(self): def test_tied_weights_keys(self):
pass pass

View File

@@ -656,10 +656,6 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# Check that the model can still do a forward pass successfully (every parameter should be resized) # Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class)) model(**self._prepare_for_class(inputs_dict, model_class))
@unittest.skip(reason="Pix2Struct doesn't use tied weights")
def test_tied_model_weights_key_ignore(self):
pass
def _create_and_check_torchscript(self, config, inputs_dict): def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript: if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False") self.skipTest(reason="test_torchscript is set to False")

View File

@@ -176,10 +176,6 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_tie_model_weights(self): def test_tie_model_weights(self):
pass pass
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone") @unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone")
def test_load_save_without_tied_weights(self): def test_load_save_without_tied_weights(self):
pass pass

View File

@@ -184,10 +184,6 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# check if it's a ones like # check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
@unittest.skip(reason="xLSTM has no tied weights")
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="xLSTM cache slicing test case is an edge case") @unittest.skip(reason="xLSTM cache slicing test case is an edge case")
def test_generate_without_input_ids(self): def test_generate_without_input_ids(self):
pass pass

View File

@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
extra_params.pop(key, None) extra_params.pop(key, None)
if not extra_params: if not extra_params:
# In that case, we *are* on a head model, but every # In that case, we *are* on a head model, but every single key is not actual parameters
# single key is not actual parameters and this is
# tested in `test_tied_model_weights_key_ignore` test.
continue continue
with tempfile.TemporaryDirectory() as temp_dir_name: with tempfile.TemporaryDirectory() as temp_dir_name:
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
self.assertEqual(infos["missing_keys"], []) self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self): def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model_tied = model_class(copy.deepcopy(config)) copied_config = copy.deepcopy(original_config)
copied_config.get_text_config().tie_word_embeddings = True
model_tied = model_class(copied_config)
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
# does not tie them
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
continue
ptrs = collections.defaultdict(list) ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items(): for name, tensor in model_tied.state_dict().items():
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
# These are all the pointers of shared tensors. # These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1] tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key # Detect we get a hit for each key
for key in tied_weight_keys: for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group) is_tied_key = any(re.search(key, p) for group in tied_params for p in group)