Tied params cleanup (#24211)
* First test * Add info for all models * style * Repo consistency * Fix last model and cleanup prints * Repo consistency * Use consistent function for detecting tied weights
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import gc
|
||||
import glob
|
||||
@@ -22,6 +23,7 @@ import os
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -127,6 +129,7 @@ if is_torch_available():
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
# Fake pretrained models for tests
|
||||
class BaseModel(PreTrainedModel):
|
||||
@@ -1662,6 +1665,33 @@ class ModelTesterMixin:
|
||||
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
|
||||
)
|
||||
|
||||
def test_tied_weights_keys(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.tie_word_embeddings = True
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
if not any(re.search(key, p) for group in tied_params for p in group):
|
||||
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
|
||||
|
||||
# Removed tied weights found from tied params -> there should only be one left after
|
||||
for key in tied_weight_keys:
|
||||
for i in range(len(tied_params)):
|
||||
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||
|
||||
tied_params = [group for group in tied_params if len(group) > 1]
|
||||
self.assertListEqual(tied_params, [])
|
||||
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user