Fix some bugs for two stage training of deformable detr (#25045)

* Update modeling_deformable_detr.py

Fix bugs for two stage training

* Update modeling_deformable_detr.py

* Add test_two_stage_training to DeformableDetrModelTest

---------

Co-authored-by: yupeng.jia <yupeng.jia@momenta.ai>
This commit is contained in:
Yupeng Jia
2023-08-02 18:30:36 +08:00
committed by GitHub
parent 1b35409768
commit 8021c684ec
3 changed files with 25 additions and 15 deletions

View File

@@ -544,6 +544,21 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_two_stage_training(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
config.two_stage = True
config.auxiliary_loss = True
config.with_box_refine = True
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
TOLERANCE = 1e-4