From 52c9e6af2928cbbe78a8b8664cad716aca9f32e0 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Wed, 4 Jan 2023 18:34:58 +0300 Subject: [PATCH] Fix bug in segmentation postprocessing (#20198) * Fix post_process_instance_segmentation * Add test for label fusing --- .../maskformer/image_processing_maskformer.py | 27 ++++++++++--------- .../test_feature_extraction_maskformer.py | 27 +++++++++++++++++++ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 1211cd6a82..aea9bb784b 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): # Get segmentation map and segment information of batch item target_size = target_sizes[i] if target_sizes is not None else None segmentation, segments = compute_segments( - mask_probs_item, - pred_scores_item, - pred_labels_item, - mask_threshold, - overlap_mask_area_threshold, - target_size, + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, ) # Return segmentation map in run-length encoding (RLE) format @@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): # Get segmentation map and segment information of batch item target_size = target_sizes[i] if target_sizes is not None else None segmentation, segments = compute_segments( - mask_probs_item, - pred_scores_item, - pred_labels_item, - mask_threshold, - overlap_mask_area_threshold, - label_ids_to_fuse, - target_size, + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, ) results.append({"segmentation": segmentation, "segments_info": segments}) diff --git a/tests/models/maskformer/test_feature_extraction_maskformer.py b/tests/models/maskformer/test_feature_extraction_maskformer.py index 624a48ac8e..2036d9f7d2 100644 --- a/tests/models/maskformer/test_feature_extraction_maskformer.py +++ b/tests/models/maskformer/test_feature_extraction_maskformer.py @@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertEqual( el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width) ) + + def test_post_process_label_fusing(self): + feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) + outputs = self.feature_extract_tester.get_fake_maskformer_outputs() + + segmentation = feature_extractor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0 + ) + unfused_segments = [el["segments_info"] for el in segmentation] + + fused_segmentation = feature_extractor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1} + ) + fused_segments = [el["segments_info"] for el in fused_segmentation] + + for el_unfused, el_fused in zip(unfused_segments, fused_segments): + if len(el_unfused) == 0: + self.assertEqual(len(el_unfused), len(el_fused)) + continue + + # Get number of segments to be fused + fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}] + num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1 + # Expected number of segments after fusing + expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse + num_segments_fused = max([el["id"] for el in el_fused]) + self.assertEqual(num_segments_fused, expected_num_segments)