Fix RT-DETR weights initialization (#31724)
* Fix init for rt-detr heads
* Fixup
* Add separate prior_prob value to config for initialization
* Add bbox init
* Change to 1 / num_labels init
* Adjust weights init test
* Fix style for test
This commit is contained in:
committed by
GitHub
parent
b97521614a
commit
048f599f35
@@ -37,6 +37,9 @@ class RTDetrConfig(PretrainedConfig):
|
|||||||
Args:
|
Args:
|
||||||
initializer_range (`float`, *optional*, defaults to 0.01):
|
initializer_range (`float`, *optional*, defaults to 0.01):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
initializer_bias_prior_prob (`float`, *optional*):
|
||||||
|
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
|
||||||
|
If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
|
||||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
@@ -179,6 +182,7 @@ class RTDetrConfig(PretrainedConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
initializer_range=0.01,
|
initializer_range=0.01,
|
||||||
|
initializer_bias_prior_prob=None,
|
||||||
layer_norm_eps=1e-5,
|
layer_norm_eps=1e-5,
|
||||||
batch_norm_eps=1e-5,
|
batch_norm_eps=1e-5,
|
||||||
# backbone
|
# backbone
|
||||||
@@ -239,6 +243,7 @@ class RTDetrConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.batch_norm_eps = batch_norm_eps
|
self.batch_norm_eps = batch_norm_eps
|
||||||
# backbone
|
# backbone
|
||||||
|
|||||||
@@ -1148,14 +1148,27 @@ class RTDetrPreTrainedModel(PreTrainedModel):
|
|||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initalize the weights"""
|
"""Initalize the weights"""
|
||||||
|
|
||||||
"""initialize conv/fc bias value according to a given probability value."""
|
"""initialize linear layer bias value according to a given probability value."""
|
||||||
if isinstance(module, nn.Linear) and hasattr(module, "class_embed"):
|
if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
|
||||||
prior_prob = self.config.initializer_range
|
if module.class_embed is not None:
|
||||||
|
for layer in module.class_embed:
|
||||||
|
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
||||||
bias = float(-math.log((1 - prior_prob) / prior_prob))
|
bias = float(-math.log((1 - prior_prob) / prior_prob))
|
||||||
nn.init.xavier_uniform_(module.weight)
|
nn.init.xavier_uniform_(layer.weight)
|
||||||
if module.bias is not None:
|
nn.init.constant_(layer.bias, bias)
|
||||||
nn.init.constant_(module.bias, bias)
|
|
||||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
if module.bbox_embed is not None:
|
||||||
|
for layer in module.bbox_embed:
|
||||||
|
nn.init.constant_(layer.layers[-1].weight, 0)
|
||||||
|
nn.init.constant_(layer.layers[-1].bias, 0)
|
||||||
|
|
||||||
|
if isinstance(module, RTDetrModel):
|
||||||
|
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
||||||
|
bias = float(-math.log((1 - prior_prob) / prior_prob))
|
||||||
|
nn.init.xavier_uniform_(module.enc_score_head.weight)
|
||||||
|
nn.init.constant_(module.enc_score_head.bias, bias)
|
||||||
|
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|||||||
@@ -584,6 +584,11 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
configs_no_init = _config_zero_init(config)
|
configs_no_init = _config_zero_init(config)
|
||||||
|
configs_no_init.initializer_bias_prior_prob = 0.2
|
||||||
|
bias_value = -1.3863 # -log_e((1 - 0.2) / 0.2), i.e. the logit of prior_prob = 0.2
|
||||||
|
|
||||||
|
failed_cases = []
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
# Skip the check for the backbone
|
# Skip the check for the backbone
|
||||||
@@ -594,21 +599,37 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if (
|
if ("class_embed" in name and "bias" in name) or "enc_score_head.bias" in name:
|
||||||
|
bias_tensor = torch.full_like(param.data, bias_value)
|
||||||
|
if not torch.allclose(param.data, bias_tensor, atol=1e-4):
|
||||||
|
failed_cases.append(
|
||||||
|
f"Parameter {name} of model {model_class} seems not properly initialized. "
|
||||||
|
f"Biases should be initialized to {bias_value}, got {param.data}"
|
||||||
|
)
|
||||||
|
elif (
|
||||||
"level_embed" in name
|
"level_embed" in name
|
||||||
or "sampling_offsets.bias" in name
|
or "sampling_offsets.bias" in name
|
||||||
or "value_proj" in name
|
or "value_proj" in name
|
||||||
or "output_proj" in name
|
or "output_proj" in name
|
||||||
or "reference_points" in name
|
or "reference_points" in name
|
||||||
|
or "enc_score_head.weight" in name
|
||||||
|
or ("class_embed" in name and "weight" in name)
|
||||||
or name in backbone_params
|
or name in backbone_params
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
self.assertIn(
|
else:
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
mean = param.data.mean()
|
||||||
[0.0, 1.0],
|
round_mean = (mean * 1e9).round() / 1e9
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
round_mean = round_mean.item()
|
||||||
|
if round_mean not in [0.0, 1.0]:
|
||||||
|
failed_cases.append(
|
||||||
|
f"Parameter {name} of model {model_class} seems not properly initialized. "
|
||||||
|
f"Mean is {round_mean}, but should be in [0, 1]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
message = "\n" + "\n".join(failed_cases)
|
||||||
|
self.assertTrue(not failed_cases, message)
|
||||||
|
|
||||||
@parameterized.expand(["float32", "float16", "bfloat16"])
|
@parameterized.expand(["float32", "float16", "bfloat16"])
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user