From 7598791c092d49555ec2aae1c92cf08a2eadb9e9 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Wed, 5 Oct 2022 23:25:58 +0300 Subject: [PATCH] Fix MaskFormer failing postprocess tests (#19354) Ensures post_process_instance_segmentation and post_process_panoptic_segmentation methods return a tensor of shape (target_height, target_width) filled with -1 values if no segment with score > threshold is found. --- .../models/maskformer/feature_extraction_maskformer.py | 10 ++++++---- .../maskformer/test_feature_extraction_maskformer.py | 7 ++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py index c86fce646b..1808c967e3 100644 --- a/src/transformers/models/maskformer/feature_extraction_maskformer.py +++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py @@ -772,8 +772,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM # No mask found if mask_probs_item.shape[0] <= 0: - segmentation = None - segments: List[Dict] = [] + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) continue # Get segmentation map and segment information of batch item @@ -860,8 +861,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM # No mask found if mask_probs_item.shape[0] <= 0: - segmentation = None - segments: List[Dict] = [] + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) continue # Get segmentation map and segment information of batch item diff --git a/tests/models/maskformer/test_feature_extraction_maskformer.py b/tests/models/maskformer/test_feature_extraction_maskformer.py index 694d272f32..fbafa9af15 100644 --- a/tests/models/maskformer/test_feature_extraction_maskformer.py +++ b/tests/models/maskformer/test_feature_extraction_maskformer.py @@ -401,10 +401,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest @unittest.skip("Fix me Alara!") def test_post_process_panoptic_segmentation(self): - fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) + feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) outputs = self.feature_extract_tester.get_fake_maskformer_outputs() - segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0) - + segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0) + print(len(segmentation)) + print(self.feature_extract_tester.batch_size) self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size) for el in segmentation: self.assertTrue("segmentation" in el)