fix auxiliary loss training in DetrSegmentation (#28354)
* fix auxiliary loss training in detrSegmentation * add auxiliary_loss testing
This commit is contained in:
committed by
GitHub
parent
8604dd308d
commit
357971ec36
@@ -1826,9 +1826,9 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
outputs_loss["pred_masks"] = pred_masks
|
outputs_loss["pred_masks"] = pred_masks
|
||||||
if self.config.auxiliary_loss:
|
if self.config.auxiliary_loss:
|
||||||
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
|
intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
|
||||||
outputs_class = self.class_labels_classifier(intermediate)
|
outputs_class = self.detr.class_labels_classifier(intermediate)
|
||||||
outputs_coord = self.bbox_predictor(intermediate).sigmoid()
|
outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
|
||||||
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
|
auxiliary_outputs = self.detr._set_aux_loss(outputs_class, outputs_coord)
|
||||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||||
|
|
||||||
loss_dict = criterion(outputs_loss, labels)
|
loss_dict = criterion(outputs_loss, labels)
|
||||||
|
|||||||
@@ -399,6 +399,22 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
self.assertIsNotNone(decoder_attentions.grad)
|
self.assertIsNotNone(decoder_attentions.grad)
|
||||||
self.assertIsNotNone(cross_attentions.grad)
|
self.assertIsNotNone(cross_attentions.grad)
|
||||||
|
|
||||||
|
def test_forward_auxiliary_loss(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.auxiliary_loss = True
|
||||||
|
|
||||||
|
# only test for object detection and segmentation model
|
||||||
|
for model_class in self.all_model_classes[1:]:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
self.assertIsNotNone(outputs.auxiliary_outputs)
|
||||||
|
self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.num_hidden_layers - 1)
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user