Revert "Fix _init_weights for ResNetPreTrainedModel" (#31868)
Revert "Fix `_init_weights` for `ResNetPreTrainedModel` (#31851)"
This reverts commit 4c8149d643.
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch RegNet model."""
|
"""PyTorch RegNet model."""
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -285,13 +284,6 @@ class RegNetPreTrainedModel(PreTrainedModel):
|
|||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
if isinstance(module, nn.Conv2d):
|
if isinstance(module, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||||
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
|
||||||
elif isinstance(module, nn.Linear):
|
|
||||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
||||||
if module.bias is not None:
|
|
||||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
|
||||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
||||||
nn.init.uniform_(module.bias, -bound, bound)
|
|
||||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
nn.init.constant_(module.weight, 1)
|
nn.init.constant_(module.weight, 1)
|
||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch ResNet model."""
|
"""PyTorch ResNet model."""
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -275,13 +274,6 @@ class ResNetPreTrainedModel(PreTrainedModel):
|
|||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
if isinstance(module, nn.Conv2d):
|
if isinstance(module, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||||
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
|
||||||
elif isinstance(module, nn.Linear):
|
|
||||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
||||||
if module.bias is not None:
|
|
||||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
|
||||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
||||||
nn.init.uniform_(module.bias, -bound, bound)
|
|
||||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
nn.init.constant_(module.weight, 1)
|
nn.init.constant_(module.weight, 1)
|
||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ PyTorch RTDetr specific ResNet model. The main difference between hugginface Res
|
|||||||
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
|
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@@ -324,13 +323,6 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel):
|
|||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
if isinstance(module, nn.Conv2d):
|
if isinstance(module, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||||
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
|
||||||
elif isinstance(module, nn.Linear):
|
|
||||||
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
||||||
if module.bias is not None:
|
|
||||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
|
||||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
||||||
nn.init.uniform_(module.bias, -bound, bound)
|
|
||||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
nn.init.constant_(module.weight, 1)
|
nn.init.constant_(module.weight, 1)
|
||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|||||||
@@ -3167,47 +3167,9 @@ class ModelTesterMixin:
|
|||||||
configs_no_init = _config_zero_init(config)
|
configs_no_init = _config_zero_init(config)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
mappings = [
|
if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
|
||||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
|
||||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
|
|
||||||
]
|
|
||||||
is_classication_model = any(model_class.__name__ in get_values(mapping) for mapping in mappings)
|
|
||||||
|
|
||||||
if not is_classication_model:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# TODO: ydshieh
|
|
||||||
is_special_classes = model_class.__name__ in [
|
|
||||||
"wav2vec2.masked_spec_embed",
|
|
||||||
"Wav2Vec2ForSequenceClassification",
|
|
||||||
"CLIPForImageClassification",
|
|
||||||
"RegNetForImageClassification",
|
|
||||||
"ResNetForImageClassification",
|
|
||||||
]
|
|
||||||
special_param_names = [
|
|
||||||
r"wav2vec2\.masked_spec_embed",
|
|
||||||
r"wav2vec2\.feature_extractor\.conv_layers\..+\.conv\.weight",
|
|
||||||
r"wav2vec2\.feature_projection\.projection\.weight",
|
|
||||||
r"wav2vec2\.feature_projection\.projection\.bias",
|
|
||||||
r"wav2vec2\.encoder\.pos_conv_embed\.conv\.parametrizations\.weight\.original.",
|
|
||||||
r"classifier\.weight",
|
|
||||||
r"regnet\.embedder\.embedder\.convolution\.weight",
|
|
||||||
r"regnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.convolution\.weight",
|
|
||||||
r"regnet\.encoder\.stages\..+\.layers\..+\.shortcut\.convolution\.weight",
|
|
||||||
r"regnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.attention\..+\.weight",
|
|
||||||
r"regnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.attention\..+\.bias",
|
|
||||||
r"classifier\..+\.weight",
|
|
||||||
r"classifier\..+\.bias",
|
|
||||||
r"resnet\.embedder\.embedder\.convolution\.weight",
|
|
||||||
r"resnet\.encoder\.stages\..+\.layers\..+\.shortcut\.convolution\.weight",
|
|
||||||
r"resnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.convolution\.weight",
|
|
||||||
r"resnet\.encoder\.stages\..+\.layers\..+\.shortcut\.convolution\.weight",
|
|
||||||
r"resnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.attention\..+\.weight",
|
|
||||||
r"resnet\.encoder\.stages\..+\.layers\..+\.layer\..+\.attention\..+\.bias",
|
|
||||||
]
|
|
||||||
|
|
||||||
with self.subTest(msg=f"Testing {model_class}"):
|
with self.subTest(msg=f"Testing {model_class}"):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model = model_class(configs_no_init)
|
model = model_class(configs_no_init)
|
||||||
@@ -3215,37 +3177,23 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# Fails when we don't set ignore_mismatched_sizes=True
|
# Fails when we don't set ignore_mismatched_sizes=True
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
new_model = model_class.from_pretrained(tmp_dir, num_labels=42)
|
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
|
||||||
|
|
||||||
logger = logging.get_logger("transformers.modeling_utils")
|
logger = logging.get_logger("transformers.modeling_utils")
|
||||||
|
|
||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True)
|
new_model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
self.assertIn("the shapes did not match", cl.out)
|
self.assertIn("the shapes did not match", cl.out)
|
||||||
|
|
||||||
for name, param in new_model.named_parameters():
|
for name, param in new_model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
param_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
|
self.assertIn(
|
||||||
if not (
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
is_special_classes
|
[0.0, 1.0],
|
||||||
and any(len(re.findall(target, name)) > 0 for target in special_param_names)
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
):
|
)
|
||||||
self.assertIn(
|
|
||||||
param_mean,
|
|
||||||
[0.0, 1.0],
|
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertGreaterEqual(
|
|
||||||
param_mean,
|
|
||||||
-1.0,
|
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
|
||||||
)
|
|
||||||
self.assertLessEqual(
|
|
||||||
param_mean,
|
|
||||||
1.0,
|
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
|
def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
|
||||||
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||||
|
|||||||
Reference in New Issue
Block a user