fix auxiliary loss training in DetrSegmentation (#28354)
* fix auxiliary loss training in detrSegmentation * add auxiliary_loss testing
This commit is contained in:
committed by
GitHub
parent
8604dd308d
commit
357971ec36
@@ -1826,9 +1826,9 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
||||
outputs_loss["pred_masks"] = pred_masks
|
||||
if self.config.auxiliary_loss:
|
||||
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
|
||||
outputs_class = self.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_class = self.detr.class_labels_classifier(intermediate)
|
||||
outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
|
||||
auxiliary_outputs = self.detr._set_aux_loss(outputs_class, outputs_coord)
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
|
||||
@@ -399,6 +399,22 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.assertIsNotNone(decoder_attentions.grad)
|
||||
self.assertIsNotNone(cross_attentions.grad)
|
||||
|
||||
def test_forward_auxiliary_loss(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.auxiliary_loss = True
|
||||
|
||||
# only test for object detection and segmentation model
|
||||
for model_class in self.all_model_classes[1:]:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertIsNotNone(outputs.auxiliary_outputs)
|
||||
self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.num_hidden_layers - 1)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user