Fix bug in segmentation postprocessing (#20198)

* Fix post_process_instance_segmentation
* Add test for label fusing
This commit is contained in:
Alara Dirik
2023-01-04 18:34:58 +03:00
committed by GitHub
parent 292acd71d6
commit 52c9e6af29
2 changed files with 41 additions and 13 deletions

View File

@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertEqual(
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
)
def test_post_process_label_fusing(self):
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
)
unfused_segments = [el["segments_info"] for el in segmentation]
fused_segmentation = feature_extractor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
)
fused_segments = [el["segments_info"] for el in fused_segmentation]
for el_unfused, el_fused in zip(unfused_segments, fused_segments):
if len(el_unfused) == 0:
self.assertEqual(len(el_unfused), len(el_fused))
continue
# Get number of segments to be fused
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
# Expected number of segments after fusing
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
num_segments_fused = max([el["id"] for el in el_fused])
self.assertEqual(num_segments_fused, expected_num_segments)