Tests for MaskFormerFeatureExtractor's post_process*** methods (#15929)
* proper tests for post_process*** methods in feature extractor * mask th == 0 * Update tests/maskformer/test_feature_extraction_maskformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f0aacc140b
commit
040c11f6da
@@ -29,6 +29,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from transformers import MaskFormerFeatureExtractor
|
from transformers import MaskFormerFeatureExtractor
|
||||||
|
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -61,6 +62,12 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
|
|||||||
self.image_mean = image_mean
|
self.image_mean = image_mean
|
||||||
self.image_std = image_std
|
self.image_std = image_std
|
||||||
self.size_divisibility = 0
|
self.size_divisibility = 0
|
||||||
|
# for the post_process_functions
|
||||||
|
self.batch_size = 2
|
||||||
|
self.num_queries = 3
|
||||||
|
self.num_classes = 2
|
||||||
|
self.height = 3
|
||||||
|
self.width = 4
|
||||||
|
|
||||||
def prepare_feat_extract_dict(self):
|
def prepare_feat_extract_dict(self):
|
||||||
return {
|
return {
|
||||||
@@ -104,6 +111,13 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
|
|||||||
|
|
||||||
return expected_height, expected_width
|
return expected_height, expected_width
|
||||||
|
|
||||||
|
def get_fake_maskformer_outputs(self):
|
||||||
|
return MaskFormerForInstanceSegmentationOutput(
|
||||||
|
# +1 for null class
|
||||||
|
class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)),
|
||||||
|
masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
@@ -301,3 +315,61 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
|||||||
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
|
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
|
||||||
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
|
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
|
||||||
self.assertEqual(mask_labels.shape[1], num_classes)
|
self.assertEqual(mask_labels.shape[1], num_classes)
|
||||||
|
|
||||||
|
def test_post_process_segmentation(self):
|
||||||
|
fature_extractor = self.feature_extraction_class()
|
||||||
|
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||||
|
segmentation = fature_extractor.post_process_segmentation(outputs)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
segmentation.shape,
|
||||||
|
(
|
||||||
|
self.feature_extract_tester.batch_size,
|
||||||
|
self.feature_extract_tester.num_classes,
|
||||||
|
self.feature_extract_tester.height,
|
||||||
|
self.feature_extract_tester.width,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
target_size = (1, 4)
|
||||||
|
segmentation = fature_extractor.post_process_segmentation(outputs, target_size=target_size)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
segmentation.shape,
|
||||||
|
(self.feature_extract_tester.batch_size, self.feature_extract_tester.num_classes, *target_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_post_process_semantic_segmentation(self):
|
||||||
|
fature_extractor = self.feature_extraction_class()
|
||||||
|
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||||
|
|
||||||
|
segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
segmentation.shape,
|
||||||
|
(
|
||||||
|
self.feature_extract_tester.batch_size,
|
||||||
|
self.feature_extract_tester.height,
|
||||||
|
self.feature_extract_tester.width,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
target_size = (1, 4)
|
||||||
|
|
||||||
|
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
|
||||||
|
|
||||||
|
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
|
||||||
|
|
||||||
|
def test_post_process_panoptic_segmentation(self):
|
||||||
|
fature_extractor = self.feature_extraction_class()
|
||||||
|
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||||
|
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
|
||||||
|
|
||||||
|
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
|
||||||
|
for el in segmentation:
|
||||||
|
self.assertTrue("segmentation" in el)
|
||||||
|
self.assertTrue("segments" in el)
|
||||||
|
self.assertEqual(type(el["segments"]), list)
|
||||||
|
self.assertEqual(
|
||||||
|
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
|
||||||
|
)
|
||||||
|
|||||||
@@ -404,23 +404,3 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
self.assertTrue(outputs.loss is not None)
|
self.assertTrue(outputs.loss is not None)
|
||||||
|
|
||||||
def test_panoptic_segmentation(self):
|
|
||||||
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
|
||||||
feature_extractor = self.default_feature_extractor
|
|
||||||
|
|
||||||
inputs = feature_extractor(
|
|
||||||
[np.zeros((3, 384, 384)), np.zeros((3, 384, 384))],
|
|
||||||
annotations=[
|
|
||||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
|
||||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
|
||||||
],
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**inputs)
|
|
||||||
|
|
||||||
panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs)
|
|
||||||
|
|
||||||
self.assertTrue(len(panoptic_segmentation) == 2)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user