Fix some bugs for two stage training of deformable detr (#25045)
* Update modeling_deformable_detr.py Fix bugs for two stage training * Update modeling_deformable_detr.py * Add test_two_stage_training to DeformableDetrModelTest --------- Co-authored-by: yupeng.jia <yupeng.jia@momenta.ai>
This commit is contained in:
@@ -544,6 +544,21 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def test_two_stage_training(self):
|
||||
model_class = DeformableDetrForObjectDetection
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
config.two_stage = True
|
||||
config.auxiliary_loss = True
|
||||
config.with_box_refine = True
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user