Fix RT-DETR weights initialization (#31724)

* Fix init for rt-detr heads

* Fixup

* Add separate prior_prob value to config for initialization

* Add bbox init

* Change to 1 / num_labels init

* Adjust weights init test

* Fix style for test
This commit is contained in:
Pavel Iakubovskii
2024-07-03 14:29:02 +01:00
committed by GitHub
parent b97521614a
commit 048f599f35
3 changed files with 52 additions and 13 deletions

View File

@@ -584,6 +584,11 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
configs_no_init.initializer_bias_prior_prob = 0.2
bias_value = -1.3863 # log_e ((1 - 0.2) / 0.2)
failed_cases = []
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
@@ -594,20 +599,36 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, param in model.named_parameters():
if param.requires_grad:
if (
if ("class_embed" in name and "bias" in name) or "enc_score_head.bias" in name:
bias_tensor = torch.full_like(param.data, bias_value)
if not torch.allclose(param.data, bias_tensor, atol=1e-4):
failed_cases.append(
f"Parameter {name} of model {model_class} seems not properly initialized. "
f"Biases should be initialized to {bias_value}, got {param.data}"
)
elif (
"level_embed" in name
or "sampling_offsets.bias" in name
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
or "enc_score_head.weight" in name
or ("class_embed" in name and "weight" in name)
or name in backbone_params
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
mean = param.data.mean()
round_mean = (mean * 1e9).round() / 1e9
round_mean = round_mean.item()
if round_mean not in [0.0, 1.0]:
failed_cases.append(
f"Parameter {name} of model {model_class} seems not properly initialized. "
f"Mean is {round_mean}, but should be in [0, 1]"
)
message = "\n" + "\n".join(failed_cases)
self.assertTrue(not failed_cases, message)
@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_gpu