More robust tied weight test (#39681)
* Update test_modeling_common.py * remove old ones * Update test_modeling_common.py * Update test_modeling_common.py * add * Update test_modeling_musicgen_melody.py
This commit is contained in:
@@ -2465,9 +2465,7 @@ class ModelTesterMixin:
|
||||
extra_params.pop(key, None)
|
||||
|
||||
if not extra_params:
|
||||
# In that case, we *are* on a head model, but every
|
||||
# single key is not actual parameters and this is
|
||||
# tested in `test_tied_model_weights_key_ignore` test.
|
||||
# In that case, we *are* on a head model, but every single key is not actual parameters
|
||||
continue
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
@@ -2564,9 +2562,17 @@ class ModelTesterMixin:
|
||||
self.assertEqual(infos["missing_keys"], [])
|
||||
|
||||
def test_tied_weights_keys(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(copy.deepcopy(config))
|
||||
copied_config = copy.deepcopy(original_config)
|
||||
copied_config.get_text_config().tie_word_embeddings = True
|
||||
model_tied = model_class(copied_config)
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
|
||||
# does not tie them
|
||||
if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
|
||||
continue
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
@@ -2575,7 +2581,6 @@ class ModelTesterMixin:
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||
|
||||
Reference in New Issue
Block a user