MaskFormer post-processing fixes and improvements (#19172)
- Improves the MaskFormer docs and corrects minor typos.
- Restructures MaskFormerFeatureExtractor.post_process_panoptic_segmentation for better readability and adds a target_sizes argument for optional resizing.
- Adds the post_process_semantic_segmentation and post_process_instance_segmentation methods.
- Adds a deprecation warning to the post_process_segmentation method in favour of post_process_instance_segmentation.
This commit is contained in:
@@ -29,6 +29,7 @@ if is_torch_available():
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import MaskFormerFeatureExtractor
|
||||
from transformers.models.maskformer.feature_extraction_maskformer import binary_mask_to_rle
|
||||
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
|
||||
|
||||
if is_vision_available():
|
||||
@@ -344,6 +345,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
common(is_instance_map=False, segmentation_type="pil")
|
||||
common(is_instance_map=True, segmentation_type="pil")
|
||||
|
||||
# Checks binary_mask_to_rle on a small hand-built mask: the encoding must have
# four entries and its first two values must be 21 and 45.
def test_binary_mask_to_rle(self):
|
||||
# 20 x 50 mask, all zeros to start with.
fake_binary_mask = np.zeros((20, 50))
|
||||
# Row 0, columns 20..49 and row 1, columns 0..14 are set to 1. In a row-major
# flattening these two spans are adjacent (flat indices 20..64, 45 pixels).
fake_binary_mask[0, 20:] = 1
|
||||
fake_binary_mask[1, :15] = 1
|
||||
# Row 5, columns 0..9: a second, separate group of ones.
fake_binary_mask[5, :10] = 1
|
||||
|
||||
rle = binary_mask_to_rle(fake_binary_mask)
|
||||
# NOTE(review): presumably the result is a flat list of (start, length) pairs,
# which would give two runs -> four entries -- confirm against binary_mask_to_rle.
self.assertEqual(len(rle), 4)
|
||||
# NOTE(review): 21 looks like the 1-based flat start of the first run and 45
# its length (30 pixels from row 0 + 15 from row 1) -- confirm against the helper.
self.assertEqual(rle[0], 21)
|
||||
self.assertEqual(rle[1], 45)
|
||||
|
||||
def test_post_process_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
@@ -373,31 +385,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
|
||||
|
||||
self.assertEqual(len(segmentation), self.feature_extract_tester.batch_size)
|
||||
self.assertEqual(
|
||||
segmentation.shape,
|
||||
segmentation[0].shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.height,
|
||||
self.feature_extract_tester.width,
|
||||
),
|
||||
)
|
||||
|
||||
target_size = (1, 4)
|
||||
target_sizes = [(1, 4) for i in range(self.feature_extract_tester.batch_size)]
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
|
||||
|
||||
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
|
||||
|
||||
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
|
||||
self.assertEqual(segmentation[0].shape, target_sizes[0])
|
||||
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
|
||||
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
|
||||
|
||||
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
self.assertTrue("segments" in el)
|
||||
self.assertEqual(type(el["segments"]), list)
|
||||
self.assertTrue("segments_info" in el)
|
||||
self.assertEqual(type(el["segments_info"]), list)
|
||||
self.assertEqual(
|
||||
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user