Fix DETA save_pretrained (#30326)
* Add class_embed to tied weights for DETA
* Fix test_tied_weights_keys for DETA model
* Replace error raise with assert statement
This commit is contained in:
committed by
GitHub
parent
6c7335e053
commit
13b3b90ab1
@@ -1888,7 +1888,7 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
class DetaForObjectDetection(DetaPreTrainedModel):
|
class DetaForObjectDetection(DetaPreTrainedModel):
|
||||||
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
||||||
_tied_weights_keys = [r"bbox_embed\.\d+"]
|
_tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
|
||||||
# We can't initialize the model on meta device as some weights are modified during the initialization
|
# We can't initialize the model on meta device as some weights are modified during the initialization
|
||||||
_no_split_modules = None
|
_no_split_modules = None
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,10 @@
|
|||||||
""" Testing suite for the PyTorch DETA model. """
|
""" Testing suite for the PyTorch DETA model. """
|
||||||
|
|
||||||
|
|
||||||
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
|
from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
|
||||||
@@ -32,6 +34,8 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.pytorch_utils import id_tensor_storage
|
||||||
|
|
||||||
if is_torchvision_available():
|
if is_torchvision_available():
|
||||||
from transformers import DetaForObjectDetection, DetaModel
|
from transformers import DetaForObjectDetection, DetaModel
|
||||||
|
|
||||||
@@ -520,6 +524,43 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
|
||||||
|
def test_tied_weights_keys(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# We need to pass model class name to correctly initialize the config.
|
||||||
|
# If we don't pass it, the config for `DetaForObjectDetection` will be initialized
|
||||||
|
# with `two_stage=False` and the test will fail because for that case `class_embed`
|
||||||
|
# weights are not tied.
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
|
||||||
|
config.tie_word_embeddings = True
|
||||||
|
|
||||||
|
model_tied = model_class(config)
|
||||||
|
|
||||||
|
ptrs = collections.defaultdict(list)
|
||||||
|
for name, tensor in model_tied.state_dict().items():
|
||||||
|
ptrs[id_tensor_storage(tensor)].append(name)
|
||||||
|
|
||||||
|
# These are all the pointers of shared tensors.
|
||||||
|
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||||
|
|
||||||
|
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||||
|
# Detect we get a hit for each key
|
||||||
|
for key in tied_weight_keys:
|
||||||
|
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||||
|
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||||
|
|
||||||
|
# Removed tied weights found from tied params -> there should only be one left after
|
||||||
|
for key in tied_weight_keys:
|
||||||
|
for i in range(len(tied_params)):
|
||||||
|
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||||
|
|
||||||
|
tied_params = [group for group in tied_params if len(group) > 1]
|
||||||
|
self.assertListEqual(
|
||||||
|
tied_params,
|
||||||
|
[],
|
||||||
|
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
|
|||||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||||
# Detect we get a hit for each key
|
# Detect we get a hit for each key
|
||||||
for key in tied_weight_keys:
|
for key in tied_weight_keys:
|
||||||
if not any(re.search(key, p) for group in tied_params for p in group):
|
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||||
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
|
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||||
|
|
||||||
# Removed tied weights found from tied params -> there should only be one left after
|
# Removed tied weights found from tied params -> there should only be one left after
|
||||||
for key in tied_weight_keys:
|
for key in tied_weight_keys:
|
||||||
|
|||||||
Reference in New Issue
Block a user