From 040c11f6dac72bc3088498aa19184da677563424 Mon Sep 17 00:00:00 2001
From: Francesco Saverio Zuppichini
Date: Fri, 4 Mar 2022 18:04:19 +0100
Subject: [PATCH] Tests for MaskFormerFeatureExtractor's post_process*** methods (#15929)

* proper tests for post_process*** methods in feature extractor

* mask th == 0

* Update tests/maskformer/test_feature_extraction_maskformer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 .../test_feature_extraction_maskformer.py     | 72 +++++++++++++++++++
 tests/maskformer/test_modeling_maskformer.py  | 20 ------
 2 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/tests/maskformer/test_feature_extraction_maskformer.py b/tests/maskformer/test_feature_extraction_maskformer.py
index 3dea525585..ad4b5d6b0c 100644
--- a/tests/maskformer/test_feature_extraction_maskformer.py
+++ b/tests/maskformer/test_feature_extraction_maskformer.py
@@ -29,6 +29,7 @@ if is_torch_available():
 
     if is_vision_available():
         from transformers import MaskFormerFeatureExtractor
+        from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
 
 if is_vision_available():
     from PIL import Image
@@ -61,6 +62,12 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
         self.image_mean = image_mean
         self.image_std = image_std
         self.size_divisibility = 0
+        # for the post_process_functions
+        self.batch_size = 2
+        self.num_queries = 3
+        self.num_classes = 2
+        self.height = 3
+        self.width = 4
 
     def prepare_feat_extract_dict(self):
         return {
@@ -104,6 +111,13 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
 
         return expected_height, expected_width
 
+    def get_fake_maskformer_outputs(self):
+        return MaskFormerForInstanceSegmentationOutput(
+            # +1 for null class
+            class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)),
+            masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)),
+        )
+
 
 @require_torch
 @require_vision
@@ -301,3 +315,61 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
         self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
         self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
         self.assertEqual(mask_labels.shape[1], num_classes)
+
+    def test_post_process_segmentation(self):
+        fature_extractor = self.feature_extraction_class()
+        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
+        segmentation = fature_extractor.post_process_segmentation(outputs)
+
+        self.assertEqual(
+            segmentation.shape,
+            (
+                self.feature_extract_tester.batch_size,
+                self.feature_extract_tester.num_classes,
+                self.feature_extract_tester.height,
+                self.feature_extract_tester.width,
+            ),
+        )
+
+        target_size = (1, 4)
+        segmentation = fature_extractor.post_process_segmentation(outputs, target_size=target_size)
+
+        self.assertEqual(
+            segmentation.shape,
+            (self.feature_extract_tester.batch_size, self.feature_extract_tester.num_classes, *target_size),
+        )
+
+    def test_post_process_semantic_segmentation(self):
+        fature_extractor = self.feature_extraction_class()
+        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
+
+        segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
+
+        self.assertEqual(
+            segmentation.shape,
+            (
+                self.feature_extract_tester.batch_size,
+                self.feature_extract_tester.height,
+                self.feature_extract_tester.width,
+            ),
+        )
+
+        target_size = (1, 4)
+
+        segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
+
+        self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
+
+    def test_post_process_panoptic_segmentation(self):
+        fature_extractor = self.feature_extraction_class()
+        outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
+        segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
+
+        self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments" in el)
+            self.assertEqual(type(el["segments"]), list)
+            self.assertEqual(
+                el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
+            )
diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py
index 67151ead6f..f2e1f56f0f 100644
--- a/tests/maskformer/test_modeling_maskformer.py
+++ b/tests/maskformer/test_modeling_maskformer.py
@@ -404,23 +404,3 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
             outputs = model(**inputs)
 
         self.assertTrue(outputs.loss is not None)
-
-    def test_panoptic_segmentation(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
-        feature_extractor = self.default_feature_extractor
-
-        inputs = feature_extractor(
-            [np.zeros((3, 384, 384)), np.zeros((3, 384, 384))],
-            annotations=[
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-            ],
-            return_tensors="pt",
-        )
-
-        with torch.no_grad():
-            outputs = model(**inputs)
-
-        panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs)
-
-        self.assertTrue(len(panoptic_segmentation) == 2)