Fix DETA save_pretrained (#30326)

* Add class_embed to tied weights for DETA

* Fix test_tied_weights_keys for DETA model

* Replace error raise with assert statement
This commit is contained in:
Pavel Iakubovskii
2024-04-22 17:11:13 +01:00
committed by GitHub
parent 6c7335e053
commit 13b3b90ab1
3 changed files with 44 additions and 3 deletions

View File

@@ -1888,7 +1888,7 @@ class DetaModel(DetaPreTrainedModel):
) )
class DetaForObjectDetection(DetaPreTrainedModel): class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required # When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.\d+"] _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
# We can't initialize the model on meta device as some weights are modified during the initialization # We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None _no_split_modules = None

View File

@@ -15,8 +15,10 @@
""" Testing suite for the PyTorch DETA model. """ """ Testing suite for the PyTorch DETA model. """
import collections
import inspect import inspect
import math import math
import re
import unittest import unittest
from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
@@ -32,6 +34,8 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.pytorch_utils import id_tensor_storage
if is_torchvision_available(): if is_torchvision_available():
from transformers import DetaForObjectDetection, DetaModel from transformers import DetaForObjectDetection, DetaModel
@@ -520,6 +524,43 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
def test_tied_weights_keys(self):
    # Check that `_tied_weights_keys` accounts for every group of state-dict entries
    # that share the same underlying storage, and that each declared key pattern
    # matches at least one entry that really is shared.
    for model_class in self.all_model_classes:
        # The model class name has to be forwarded so the config is built for that class.
        # Without it, the config for `DetaForObjectDetection` is initialized with
        # `two_stage=False`, in which case the `class_embed` weights are not tied and
        # this test fails.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common(
            model_class_name=model_class.__name__
        )
        config.tie_word_embeddings = True
        model_tied = model_class(config)

        # Group state-dict entry names by the storage their tensor lives in.
        names_by_storage = collections.defaultdict(list)
        for entry_name, tensor in model_tied.state_dict().items():
            names_by_storage[id_tensor_storage(tensor)].append(entry_name)

        # Only storages referenced by more than one name correspond to shared tensors.
        tied_params = [names for names in names_by_storage.values() if len(names) > 1]

        tied_weight_keys = model_tied._tied_weights_keys
        if tied_weight_keys is None:
            tied_weight_keys = []

        # Every declared pattern must hit at least one name inside a shared group.
        for key in tied_weight_keys:
            is_tied_key = any(re.search(key, name) for names in tied_params for name in names)
            self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

        # Remove the names matched by any declared pattern from each shared group;
        # afterwards at most one name should be left per group.
        for key in tied_weight_keys:
            tied_params = [[name for name in names if re.search(key, name) is None] for names in tied_params]

        # A group that still holds several names contains undeclared tied weights.
        tied_params = [names for names in tied_params if len(names) > 1]
        self.assertListEqual(
            tied_params,
            [],
            f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
        )
TOLERANCE = 1e-4 TOLERANCE = 1e-4

View File

@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key # Detect we get a hit for each key
for key in tied_weight_keys: for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group): is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
raise ValueError(f"{key} is not a tied weight key for {model_class}.") self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after # Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys: for key in tied_weight_keys: