From 18a7c29ff8431193887e1065777e9cde29d46e53 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 25 Jul 2025 22:03:21 +0200 Subject: [PATCH] More robust tied weight test (#39681) * Update test_modeling_common.py * remove old ones * Update test_modeling_common.py * Update test_modeling_common.py * add * Update test_modeling_musicgen_melody.py --- tests/models/csm/test_modeling_csm.py | 37 +------------------ tests/models/dbrx/test_modeling_dbrx.py | 4 -- tests/models/mamba2/test_modeling_mamba2.py | 4 -- .../models/musicgen/test_modeling_musicgen.py | 6 +-- .../test_modeling_musicgen_melody.py | 6 +-- .../pix2struct/test_modeling_pix2struct.py | 4 -- .../test_modeling_timm_backbone.py | 4 -- tests/models/xlstm/test_modeling_xlstm.py | 4 -- tests/test_modeling_common.py | 17 ++++++--- 9 files changed, 15 insertions(+), 71 deletions(-) diff --git a/tests/models/csm/test_modeling_csm.py b/tests/models/csm/test_modeling_csm.py index 15467f5e1b..7679ab55f6 100644 --- a/tests/models/csm/test_modeling_csm.py +++ b/tests/models/csm/test_modeling_csm.py @@ -14,9 +14,7 @@ # limitations under the License. """Testing suite for the PyTorch ConversationalSpeechModel model.""" -import collections import copy -import re import unittest import pytest @@ -52,8 +50,6 @@ if is_datasets_available(): if is_torch_available(): import torch - from transformers.pytorch_utils import id_tensor_storage - class CsmModelTester: def __init__( @@ -344,38 +340,9 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u def test_model_parallel_beam_search(self): pass + @unittest.skip(reason="CSM has special embeddings that can never be tied") def test_tied_weights_keys(self): - """ - Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM). - """ - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - model_tied = model_class(config) - - ptrs = collections.defaultdict(list) - for name, tensor in model_tied.state_dict().items(): - ptrs[id_tensor_storage(tensor)].append(name) - - # These are all the pointers of shared tensors. - tied_params = [names for _, names in ptrs.items() if len(names) > 1] - - tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] - # Detect we get a hit for each key - for key in tied_weight_keys: - is_tied_key = any(re.search(key, p) for group in tied_params for p in group) - self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.") - - # Removed tied weights found from tied params -> there should only be one left after - for key in tied_weight_keys: - for i in range(len(tied_params)): - tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None] - - tied_params = [group for group in tied_params if len(group) > 1] - self.assertListEqual( - tied_params, - [], - f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.", - ) + pass def _get_custom_4d_mask_test_data(self): """ diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py index e89740db61..b8b7360fa7 100644 --- a/tests/models/dbrx/test_modeling_dbrx.py +++ b/tests/models/dbrx/test_modeling_dbrx.py @@ -108,10 +108,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase): model = DbrxModel.from_pretrained(model_name) self.assertIsNotNone(model) - @unittest.skip(reason="Dbrx models have weight tying disabled.") - def test_tied_weights_keys(self): - pass - # Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts. # The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked) @unittest.skip(reason="Dbrx models do not work with offload") diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index c9cec231e6..85047afb4c 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -309,10 +309,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @unittest.skip(reason="Mamba 2 weights are not tied") - def test_tied_weights_keys(self): - pass - @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") def test_multi_gpu_data_parallel_forward(self): pass diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index e7eee02ce8..8a41c47d6f 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -781,11 +781,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_tie_model_weights(self): pass - @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") - def test_tied_model_weights_key_ignore(self): - pass - - @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") + @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied") def test_tied_weights_keys(self): pass diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 3d7b45b643..72b20f345b 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -782,11 +782,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_tie_model_weights(self): pass - @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") - def test_tied_model_weights_key_ignore(self): - pass - - @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.") + @unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied") def test_tied_weights_keys(self): pass diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 3e3bfcc1f7..2b67ec2397 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -656,10 +656,6 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class)) - @unittest.skip(reason="Pix2Struct doesn't use tied weights") - def test_tied_model_weights_key_ignore(self): - pass - def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: self.skipTest(reason="test_torchscript is set to False") diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 306b9d2b06..0bf79a6131 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -176,10 +176,6 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste def test_tie_model_weights(self): pass - @unittest.skip(reason="model weights aren't tied in TimmBackbone.") - def test_tied_model_weights_key_ignore(self): - pass - @unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone") def test_load_save_without_tied_weights(self): pass diff --git a/tests/models/xlstm/test_modeling_xlstm.py b/tests/models/xlstm/test_modeling_xlstm.py index 3ad5f67100..0e10d0999d 100644 --- a/tests/models/xlstm/test_modeling_xlstm.py +++ b/tests/models/xlstm/test_modeling_xlstm.py @@ -184,10 +184,6 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # check if it's a ones like self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) - @unittest.skip(reason="xLSTM has no tied weights") - def test_tied_weights_keys(self): - pass - @unittest.skip(reason="xLSTM cache slicing test case is an edge case") def test_generate_without_input_ids(self): pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 53d44e60c8..38c581992b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2465,9 +2465,7 @@ class ModelTesterMixin: extra_params.pop(key, None) if not extra_params: - # In that case, we *are* on a head model, but every - # single key is not actual parameters and this is - # tested in `test_tied_model_weights_key_ignore` test. + # In that case, we *are* on a head model, but every single key is not actual parameters continue with tempfile.TemporaryDirectory() as temp_dir_name: @@ -2564,9 +2562,17 @@ class ModelTesterMixin: self.assertEqual(infos["missing_keys"], []) def test_tied_weights_keys(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + original_config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model_tied = model_class(copy.deepcopy(config)) + copied_config = copy.deepcopy(original_config) + copied_config.get_text_config().tie_word_embeddings = True + model_tied = model_class(copied_config) + + tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] + # If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model + # does not tie them + if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings: + continue ptrs = collections.defaultdict(list) for name, tensor in model_tied.state_dict().items(): @@ -2575,7 +2581,6 @@ class ModelTesterMixin: # These are all the pointers of shared tensors. tied_params = [names for _, names in ptrs.items() if len(names) > 1] - tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] # Detect we get a hit for each key for key in tied_weight_keys: is_tied_key = any(re.search(key, p) for group in tied_params for p in group)