More robust tied weight test (#39681)
* Update test_modeling_common.py
* remove old ones
* Update test_modeling_common.py
* Update test_modeling_common.py
* add
* Update test_modeling_musicgen_melody.py
This commit is contained in:
@@ -14,9 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
|
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
|
||||||
|
|
||||||
import collections
|
|
||||||
import copy
|
import copy
|
||||||
import re
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -52,8 +50,6 @@ if is_datasets_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.pytorch_utils import id_tensor_storage
|
|
||||||
|
|
||||||
|
|
||||||
class CsmModelTester:
|
class CsmModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -344,38 +340,9 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
|
|||||||
def test_model_parallel_beam_search(self):
|
def test_model_parallel_beam_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="CSM has special embeddings that can never be tied")
|
||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
"""
|
pass
|
||||||
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
|
|
||||||
"""
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model_tied = model_class(config)
|
|
||||||
|
|
||||||
ptrs = collections.defaultdict(list)
|
|
||||||
for name, tensor in model_tied.state_dict().items():
|
|
||||||
ptrs[id_tensor_storage(tensor)].append(name)
|
|
||||||
|
|
||||||
# These are all the pointers of shared tensors.
|
|
||||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
|
||||||
|
|
||||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
|
||||||
# Detect we get a hit for each key
|
|
||||||
for key in tied_weight_keys:
|
|
||||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
|
||||||
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
|
||||||
|
|
||||||
# Removed tied weights found from tied params -> there should only be one left after
|
|
||||||
for key in tied_weight_keys:
|
|
||||||
for i in range(len(tied_params)):
|
|
||||||
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
|
||||||
|
|
||||||
tied_params = [group for group in tied_params if len(group) > 1]
|
|
||||||
self.assertListEqual(
|
|
||||||
tied_params,
|
|
||||||
[],
|
|
||||||
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_custom_4d_mask_test_data(self):
|
def _get_custom_4d_mask_test_data(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -108,10 +108,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
|
|||||||
model = DbrxModel.from_pretrained(model_name)
|
model = DbrxModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip(reason="Dbrx models have weight tying disabled.")
|
|
||||||
def test_tied_weights_keys(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
|
# Offload does not work with Dbrx models because of the forward of DbrxExperts where we chunk the experts.
|
||||||
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
|
# The issue is that the offloaded weights of the mlp layer are still on meta device (w1_chunked, v1_chunked, w2_chunked)
|
||||||
@unittest.skip(reason="Dbrx models do not work with offload")
|
@unittest.skip(reason="Dbrx models do not work with offload")
|
||||||
|
|||||||
@@ -309,10 +309,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip(reason="Mamba 2 weights are not tied")
|
|
||||||
def test_tied_weights_keys(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||||
def test_multi_gpu_data_parallel_forward(self):
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -781,11 +781,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
|
||||||
def test_tied_model_weights_key_ignore(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
|
||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -782,11 +782,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied")
|
||||||
def test_tied_model_weights_key_ignore(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="MusicGen has multiple inputs embeds and lm heads that should not be tied.")
|
|
||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -656,10 +656,6 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
@unittest.skip(reason="Pix2Struct doesn't use tied weights")
|
|
||||||
def test_tied_model_weights_key_ignore(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
if not self.test_torchscript:
|
if not self.test_torchscript:
|
||||||
self.skipTest(reason="test_torchscript is set to False")
|
self.skipTest(reason="test_torchscript is set to False")
|
||||||
|
|||||||
@@ -176,10 +176,6 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
|
|||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")
|
|
||||||
def test_tied_model_weights_key_ignore(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone")
|
@unittest.skip(reason="Only checkpoints on timm can be loaded into TimmBackbone")
|
||||||
def test_load_save_without_tied_weights(self):
|
def test_load_save_without_tied_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -184,10 +184,6 @@ class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
# check if it's a ones like
|
# check if it's a ones like
|
||||||
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||||
|
|
||||||
@unittest.skip(reason="xLSTM has no tied weights")
|
|
||||||
def test_tied_weights_keys(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
|
|||||||
extra_params.pop(key, None)
|
extra_params.pop(key, None)
|
||||||
|
|
||||||
if not extra_params:
|
if not extra_params:
|
||||||
# In that case, we *are* on a head model, but every
|
# In that case, we *are* on a head model, but every single key is not actual parameters
|
||||||
# single key is not actual parameters and this is
|
|
||||||
# tested in `test_tied_model_weights_key_ignore` test.
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||||
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
|
|||||||
self.assertEqual(infos["missing_keys"], [])
|
self.assertEqual(infos["missing_keys"], [])
|
||||||
|
|
||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model_tied = model_class(copy.deepcopy(config))
|
copied_config = copy.deepcopy(original_config)
|
||||||
|
copied_config.get_text_config().tie_word_embeddings = True
|
||||||
|
model_tied = model_class(copied_config)
|
||||||
|
|
||||||
|
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||||
|
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
|
||||||
|
# does not tie them
|
||||||
|
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
|
||||||
|
continue
|
||||||
|
|
||||||
ptrs = collections.defaultdict(list)
|
ptrs = collections.defaultdict(list)
|
||||||
for name, tensor in model_tied.state_dict().items():
|
for name, tensor in model_tied.state_dict().items():
|
||||||
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
|
|||||||
# These are all the pointers of shared tensors.
|
# These are all the pointers of shared tensors.
|
||||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||||
|
|
||||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
|
||||||
# Detect we get a hit for each key
|
# Detect we get a hit for each key
|
||||||
for key in tied_weight_keys:
|
for key in tied_weight_keys:
|
||||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||||
|
|||||||
Reference in New Issue
Block a user