🚨🚨 Fix initialization of Mask2Former (#38864)
* Correctly fix init Co-authored-by: BUI Van Tuan <buivantuan07@gmail.com> * add back the block, breaking BC but this is correct author's code * override the test for params needing it --------- Co-authored-by: BUI Van Tuan <buivantuan07@gmail.com>
This commit is contained in:
@@ -2127,30 +2127,20 @@ class Mask2FormerPreTrainedModel(PreTrainedModel):
|
|||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
if p.dim() > 1:
|
if p.dim() > 1:
|
||||||
nn.init.xavier_uniform_(p, gain=xavier_std)
|
nn.init.xavier_uniform_(p, gain=xavier_std)
|
||||||
|
module.cross_attn.in_proj_bias.data.zero_()
|
||||||
elif isinstance(module, Mask2FormerPixelLevelModule):
|
|
||||||
for submodule in module.modules():
|
|
||||||
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
|
|
||||||
submodule.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if submodule.bias is not None:
|
|
||||||
submodule.bias.data.zero_()
|
|
||||||
|
|
||||||
elif isinstance(module, Mask2FormerPixelDecoder):
|
elif isinstance(module, Mask2FormerPixelDecoder):
|
||||||
for p in module.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
nn.init.normal_(module.level_embed, std=0)
|
nn.init.normal_(module.level_embed, std=0)
|
||||||
|
|
||||||
elif isinstance(module, Mask2FormerPixelDecoderEncoderOnly):
|
|
||||||
for p in module.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
elif isinstance(module, nn.Embedding):
|
elif isinstance(module, nn.Embedding):
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
|
|||||||
@@ -324,12 +324,7 @@ def load_backbone(config):
|
|||||||
raise ValueError("Cannot specify both config.backbone_config and config.backbone")
|
raise ValueError("Cannot specify both config.backbone_config and config.backbone")
|
||||||
|
|
||||||
# If any of thhe following are set, then the config passed in is from a model which contains a backbone.
|
# If any of thhe following are set, then the config passed in is from a model which contains a backbone.
|
||||||
if (
|
if backbone_config is None and use_timm_backbone is None and backbone_checkpoint is None:
|
||||||
backbone_config is None
|
|
||||||
and use_timm_backbone is None
|
|
||||||
and backbone_checkpoint is None
|
|
||||||
and backbone_checkpoint is None
|
|
||||||
):
|
|
||||||
return AutoBackbone.from_config(config=config, **backbone_kwargs)
|
return AutoBackbone.from_config(config=config, **backbone_kwargs)
|
||||||
|
|
||||||
# config from the parent model that has a backbone
|
# config from the parent model that has a backbone
|
||||||
|
|||||||
@@ -590,15 +590,14 @@ class DeformableDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if param.requires_grad:
|
if (
|
||||||
if (
|
"level_embed" in name
|
||||||
"level_embed" in name
|
or "sampling_offsets.bias" in name
|
||||||
or "sampling_offsets.bias" in name
|
or "value_proj" in name
|
||||||
or "value_proj" in name
|
or "output_proj" in name
|
||||||
or "output_proj" in name
|
or "reference_points" in name
|
||||||
or "reference_points" in name
|
):
|
||||||
):
|
continue
|
||||||
continue
|
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
|
from transformers import AutoModelForImageClassification, Mask2FormerConfig, is_torch_available, is_vision_available
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_timm,
|
require_timm,
|
||||||
@@ -33,7 +33,7 @@ from transformers.testing_utils import (
|
|||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -350,6 +350,58 @@ class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation":
|
elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation":
|
||||||
self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3])
|
self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3])
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if (
|
||||||
|
"self_attn.sampling_offsets.bias" in name
|
||||||
|
or "self_attn.value_proj.weight" in name
|
||||||
|
or "self_attn.output_proj.weight" in name
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialization_pretrained_backbone(self):
|
||||||
|
backbone_name = "microsoft/resnet-18"
|
||||||
|
|
||||||
|
# load Mask2Former config with a pretrained backbone
|
||||||
|
config = Mask2FormerConfig(
|
||||||
|
backbone=backbone_name,
|
||||||
|
use_pretrained_backbone=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load pretrained backbone
|
||||||
|
backbone_model = AutoModelForImageClassification.from_pretrained(backbone_name, device_map=torch_device)
|
||||||
|
|
||||||
|
def params_match(params1, params2):
|
||||||
|
return all((p1 == p2).all() for p1, p2 in zip(params1, params2))
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
if model.__class__.__name__ == "Mask2FormerModel":
|
||||||
|
self.assertTrue(
|
||||||
|
params_match(
|
||||||
|
backbone_model.base_model.encoder.parameters(),
|
||||||
|
model.pixel_level_module.encoder.encoder.parameters(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation":
|
||||||
|
self.assertTrue(
|
||||||
|
params_match(
|
||||||
|
backbone_model.base_model.encoder.parameters(),
|
||||||
|
model.model.pixel_level_module.encoder.encoder.parameters(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user