More robust tied weight test (#39681)
* Update test_modeling_common.py * remove old ones * Update test_modeling_common.py * Update test_modeling_common.py * add * Update test_modeling_musicgen_melody.py
This commit is contained in:
@@ -14,9 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -52,8 +50,6 @@ if is_datasets_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
class CsmModelTester:
|
||||
def __init__(
|
||||
@@ -344,38 +340,9 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CSM has special embeddings that can never be tied")
|
||||
def test_tied_weights_keys(self):
|
||||
"""
|
||||
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
|
||||
"""
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||
|
||||
# Removed tied weights found from tied params -> there should only be one left after
|
||||
for key in tied_weight_keys:
|
||||
for i in range(len(tied_params)):
|
||||
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||
|
||||
tied_params = [group for group in tied_params if len(group) > 1]
|
||||
self.assertListEqual(
|
||||
tied_params,
|
||||
[],
|
||||
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
||||
)
|
||||
pass
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
"""
|
||||
|
||||
@@ -108,10 +108,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
model = DbrxModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Dbrx models have weight tying disabled.")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
|
||||
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
|
||||
@unittest.skip(reason="Dbrx models do not work with offload")
|
||||
|
||||
@@ -309,10 +309,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Mamba 2 weights are not tied")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@@ -781,11 +781,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -782,11 +782,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -656,10 +656,6 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@unittest.skip(reason="Pix2Struct doesn't use tied weights")
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
pass
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
self.skipTest(reason="test_torchscript is set to False")
|
||||
|
||||
@@ -176,10 +176,6 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone")
|
||||
def test_load_save_without_tied_weights(self):
|
||||
pass
|
||||
|
||||
@@ -184,10 +184,6 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# check if it's a ones like
|
||||
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||
|
||||
@unittest.skip(reason="xLSTM has no tied weights")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
|
||||
extra_params.pop(key, None)
|
||||
|
||||
if not extra_params:
|
||||
# In that case, we *are* on a head model, but every
|
||||
# single key is not actual parameters and this is
|
||||
# tested in `test_tied_model_weights_key_ignore` test.
|
||||
# In that case, we *are* on a head model, but every single key is not actual parameters
|
||||
continue
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
|
||||
self.assertEqual(infos["missing_keys"], [])
|
||||
|
||||
def test_tied_weights_keys(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(copy.deepcopy(config))
|
||||
copied_config = copy.deepcopy(original_config)
|
||||
copied_config.get_text_config().tie_word_embeddings = True
|
||||
model_tied = model_class(copied_config)
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
|
||||
# does not tie them
|
||||
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
|
||||
continue
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||
|
||||
Reference in New Issue
Block a user