* Revert "Revert "Fix `_init_weights` for `ResNetPreTrainedModel`" (#31868)"
This reverts commit b45dd5de9c.
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
* fix
* [test_all] check
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -3167,9 +3167,68 @@ class ModelTesterMixin:
|
||||
configs_no_init = _config_zero_init(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
|
||||
mappings = [
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
|
||||
]
|
||||
is_classication_model = any(model_class.__name__ in get_values(mapping) for mapping in mappings)
|
||||
|
||||
if not is_classication_model:
|
||||
continue
|
||||
|
||||
# TODO: ydshieh
|
||||
is_special_classes = model_class.__name__ in [
|
||||
"wav2vec2.masked_spec_embed",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"CLIPForImageClassification",
|
||||
"RegNetForImageClassification",
|
||||
"ResNetForImageClassification",
|
||||
"UniSpeechSatForSequenceClassification",
|
||||
"Wav2Vec2BertForSequenceClassification",
|
||||
"PvtV2ForImageClassification",
|
||||
"Wav2Vec2ConformerForSequenceClassification",
|
||||
"WavLMForSequenceClassification",
|
||||
"SwiftFormerForImageClassification",
|
||||
"SEWForSequenceClassification",
|
||||
"BitForImageClassification",
|
||||
"SEWDForSequenceClassification",
|
||||
"SiglipForImageClassification",
|
||||
"HubertForSequenceClassification",
|
||||
"Swinv2ForImageClassification",
|
||||
"Data2VecAudioForSequenceClassification",
|
||||
"UniSpeechForSequenceClassification",
|
||||
"PvtForImageClassification",
|
||||
]
|
||||
special_param_names = [
|
||||
r"^bit\.",
|
||||
r"^classifier\.weight",
|
||||
r"^classifier\.bias",
|
||||
r"^classifier\..+\.weight",
|
||||
r"^classifier\..+\.bias",
|
||||
r"^data2vec_audio\.",
|
||||
r"^dist_head\.",
|
||||
r"^head\.",
|
||||
r"^hubert\.",
|
||||
r"^pvt\.",
|
||||
r"^pvt_v2\.",
|
||||
r"^regnet\.",
|
||||
r"^resnet\.",
|
||||
r"^sew\.",
|
||||
r"^sew_d\.",
|
||||
r"^swiftformer\.",
|
||||
r"^swinv2\.",
|
||||
r"^transformers\.models\.swiftformer\.",
|
||||
r"^unispeech\.",
|
||||
r"^unispeech_sat\.",
|
||||
r"^vision_model\.",
|
||||
r"^wav2vec2\.",
|
||||
r"^wav2vec2_bert\.",
|
||||
r"^wav2vec2_conformer\.",
|
||||
r"^wavlm\.",
|
||||
]
|
||||
|
||||
with self.subTest(msg=f"Testing {model_class}"):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = model_class(configs_no_init)
|
||||
@@ -3177,23 +3236,41 @@ class ModelTesterMixin:
|
||||
|
||||
# Fails when we don't set ignore_mismatched_sizes=True
|
||||
with self.assertRaises(RuntimeError):
|
||||
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||
new_model = model_class.from_pretrained(tmp_dir, num_labels=42)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
||||
)
|
||||
new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True)
|
||||
self.assertIn("the shapes did not match", cl.out)
|
||||
|
||||
for name, param in new_model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
param_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
|
||||
if not (
|
||||
is_special_classes
|
||||
and any(len(re.findall(target, name)) > 0 for target in special_param_names)
|
||||
):
|
||||
self.assertIn(
|
||||
param_mean,
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
else:
|
||||
# Here we allow the parameters' mean to be in the range [-5.0, 5.0] instead of being
|
||||
# either `0.0` or `1.0`, because their initializations are not using
|
||||
# `config.initializer_factor` (or something similar). The purpose of this test is simply
|
||||
# to make sure they are properly initialized (to avoid very large value or even `nan`).
|
||||
self.assertGreaterEqual(
|
||||
param_mean,
|
||||
-5.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
self.assertLessEqual(
|
||||
param_mean,
|
||||
5.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
|
||||
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||
|
||||
Reference in New Issue
Block a user