From 2875fa971cf480f4411d5684844eb5aa8f9870ce Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Tue, 20 Dec 2022 16:46:50 +0100
Subject: [PATCH] [SegFormer] Add support for segmentation masks with one label (#20279)

* Add support for binary segmentation

* Fix loss calculation and add test

* Remove space

* use fstring

Co-authored-by: Niels Rogge
Co-authored-by: Niels Rogge
---
Review notes (ignored by `git am`; not part of the commit message):
  - The binary-segmentation test draws labels with torch.randint(0, 2, ...):
    `high` is exclusive, so randint(0, 1, ...) only ever produced zeros and
    the positive class was never exercised.
  - The ValueError branch is reached only when num_labels < 1, so its message
    now states the actual requirement (>=1).
  - The blob ids on the `index` lines describe the unedited post-image.

 .../models/segformer/modeling_segformer.py    | 19 ++++++++++++-------
 .../segformer/test_modeling_segformer.py      | 14 ++++++++++++++
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index 194ad307a5..57eb9fa6c4 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -806,15 +806,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
 
         loss = None
         if labels is not None:
-            if not self.config.num_labels > 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
-                # upsample logits to the images' original size
-                upsampled_logits = nn.functional.interpolate(
-                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
-                )
+            # upsample logits to the images' original size
+            upsampled_logits = nn.functional.interpolate(
+                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+            )
+            if self.config.num_labels > 1:
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)
+            elif self.config.num_labels == 1:
+                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
+                loss_fct = BCEWithLogitsLoss(reduction="none")
+                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
+                loss = (loss * valid_mask).mean()
+            else:
+                raise ValueError(f"Number of labels should be >=1: {self.config.num_labels}")
 
         if not return_dict:
             if output_hidden_states:
diff --git a/tests/models/segformer/test_modeling_segformer.py b/tests/models/segformer/test_modeling_segformer.py
index c7e3e8a92e..6037170fb1 100644
--- a/tests/models/segformer/test_modeling_segformer.py
+++ b/tests/models/segformer/test_modeling_segformer.py
@@ -140,6 +140,16 @@ class SegformerModelTester:
         self.parent.assertEqual(
             result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )
+        self.parent.assertGreater(result.loss, 0.0)
+
+    def create_and_check_for_binary_image_segmentation(self, config, pixel_values, labels):
+        config.num_labels = 1
+        model = SegformerForSemanticSegmentation(config=config)
+        model.to(torch_device)
+        model.eval()
+        labels = torch.randint(0, 2, (self.batch_size, self.image_size, self.image_size)).to(torch_device)
+        result = model(pixel_values, labels=labels)
+        self.parent.assertGreater(result.loss, 0.0)
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -177,6 +187,10 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_for_binary_image_segmentation(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_binary_image_segmentation(*config_and_inputs)
+
     def test_for_image_segmentation(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)