Fix bug in segmentation postprocessing (#20198)
* Fix post_process_instance_segmentation * Add test for label fusing
This commit is contained in:
@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
self.assertEqual(
|
||||
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||
)
|
||||
|
||||
def test_post_process_label_fusing(self):
|
||||
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
|
||||
segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
|
||||
)
|
||||
unfused_segments = [el["segments_info"] for el in segmentation]
|
||||
|
||||
fused_segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
|
||||
)
|
||||
fused_segments = [el["segments_info"] for el in fused_segmentation]
|
||||
|
||||
for el_unfused, el_fused in zip(unfused_segments, fused_segments):
|
||||
if len(el_unfused) == 0:
|
||||
self.assertEqual(len(el_unfused), len(el_fused))
|
||||
continue
|
||||
|
||||
# Get number of segments to be fused
|
||||
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
|
||||
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
|
||||
# Expected number of segments after fusing
|
||||
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
||||
num_segments_fused = max([el["id"] for el in el_fused])
|
||||
self.assertEqual(num_segments_fused, expected_num_segments)
|
||||
|
||||
Reference in New Issue
Block a user