Fix bug in segmentation postprocessing (#20198)
* Fix post_process_instance_segmentation * Add test for label fusing
This commit is contained in:
@@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
# Get segmentation map and segment information of batch item
|
# Get segmentation map and segment information of batch item
|
||||||
target_size = target_sizes[i] if target_sizes is not None else None
|
target_size = target_sizes[i] if target_sizes is not None else None
|
||||||
segmentation, segments = compute_segments(
|
segmentation, segments = compute_segments(
|
||||||
mask_probs_item,
|
mask_probs=mask_probs_item,
|
||||||
pred_scores_item,
|
pred_scores=pred_scores_item,
|
||||||
pred_labels_item,
|
pred_labels=pred_labels_item,
|
||||||
mask_threshold,
|
mask_threshold=mask_threshold,
|
||||||
overlap_mask_area_threshold,
|
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||||
target_size,
|
label_ids_to_fuse=[],
|
||||||
|
target_size=target_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return segmentation map in run-length encoding (RLE) format
|
# Return segmentation map in run-length encoding (RLE) format
|
||||||
@@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
|||||||
# Get segmentation map and segment information of batch item
|
# Get segmentation map and segment information of batch item
|
||||||
target_size = target_sizes[i] if target_sizes is not None else None
|
target_size = target_sizes[i] if target_sizes is not None else None
|
||||||
segmentation, segments = compute_segments(
|
segmentation, segments = compute_segments(
|
||||||
mask_probs_item,
|
mask_probs=mask_probs_item,
|
||||||
pred_scores_item,
|
pred_scores=pred_scores_item,
|
||||||
pred_labels_item,
|
pred_labels=pred_labels_item,
|
||||||
mask_threshold,
|
mask_threshold=mask_threshold,
|
||||||
overlap_mask_area_threshold,
|
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||||
label_ids_to_fuse,
|
label_ids_to_fuse=label_ids_to_fuse,
|
||||||
target_size,
|
target_size=target_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
results.append({"segmentation": segmentation, "segments_info": segments})
|
results.append({"segmentation": segmentation, "segments_info": segments})
|
||||||
|
|||||||
@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_post_process_label_fusing(self):
|
||||||
|
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||||
|
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||||
|
|
||||||
|
segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||||
|
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
|
||||||
|
)
|
||||||
|
unfused_segments = [el["segments_info"] for el in segmentation]
|
||||||
|
|
||||||
|
fused_segmentation = feature_extractor.post_process_panoptic_segmentation(
|
||||||
|
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
|
||||||
|
)
|
||||||
|
fused_segments = [el["segments_info"] for el in fused_segmentation]
|
||||||
|
|
||||||
|
for el_unfused, el_fused in zip(unfused_segments, fused_segments):
|
||||||
|
if len(el_unfused) == 0:
|
||||||
|
self.assertEqual(len(el_unfused), len(el_fused))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get number of segments to be fused
|
||||||
|
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
|
||||||
|
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
|
||||||
|
# Expected number of segments after fusing
|
||||||
|
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
|
||||||
|
num_segments_fused = max([el["id"] for el in el_fused])
|
||||||
|
self.assertEqual(num_segments_fused, expected_num_segments)
|
||||||
|
|||||||
Reference in New Issue
Block a user