From 3822e4a563d154df8277b3adbb38794bab302693 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 3 Mar 2022 16:31:41 +0100 Subject: [PATCH] Enabling MaskFormer in pipelines (#15917) * Enabling MaskFormer in ppipelines No AutoModel though :( * Ooops local file. --- .../feature_extraction_maskformer.py | 2 +- .../pipelines/image_segmentation.py | 13 +++++++- .../test_pipelines_image_segmentation.py | 33 +++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py index 7a0e21a93a..fce59b0a4b 100644 --- a/src/transformers/models/maskformer/feature_extraction_maskformer.py +++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py @@ -565,5 +565,5 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ) if is_stuff: stuff_memory_list[pred_class] = current_segment_id - results.append({"segmentation": segmentation, "segments": segments}) + results.append({"segmentation": segmentation, "segments": segments}) return results diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index a1eef2e3fa..2f4e6e09ab 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -110,7 +110,18 @@ class ImageSegmentationPipeline(Pipeline): return model_outputs def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5): - if hasattr(self.feature_extractor, "post_process_segmentation"): + if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"): + outputs = self.feature_extractor.post_process_panoptic_segmentation( + model_outputs, is_thing_map=self.model.config.id2label + )[0] + annotation = [] + segmentation = outputs["segmentation"] + for segment in outputs["segments"]: + mask = (segmentation == segment["id"]) * 255 + mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L") + label = self.model.config.id2label[segment["category_id"]] + annotation.append({"mask": mask, "label": label, "score": None}) + elif hasattr(self.feature_extractor, "post_process_segmentation"): # Panoptic raw_annotations = self.feature_extractor.post_process_segmentation( model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 diff --git a/tests/pipelines/test_pipelines_image_segmentation.py b/tests/pipelines/test_pipelines_image_segmentation.py index 1677cb53c3..d0f9c4f946 100644 --- a/tests/pipelines/test_pipelines_image_segmentation.py +++ b/tests/pipelines/test_pipelines_image_segmentation.py @@ -16,6 +16,7 @@ import hashlib import unittest import datasets +from datasets import load_dataset from transformers import ( MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, @@ -308,3 +309,35 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa {"score": 0.9994, "label": "cat", "mask": "88b37bd2202c750cc9dd191518050a9b0ca5228c"}, ], ) + + @require_torch + @slow + def test_maskformer(self): + threshold = 0.999 + model_id = "facebook/maskformer-swin-base-ade" + + from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation + + model = MaskFormerForInstanceSegmentation.from_pretrained(model_id) + feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id) + + image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor) + + image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + outputs = image_segmenter(image[0]["file"], threshold=threshold) + + for o in outputs: + o["mask"] = hashimage(o["mask"]) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"mask": "20d1b9480d1dc1501dbdcfdff483e370", "label": "wall", "score": None}, + {"mask": "0f902fbc66a0ff711ea455b0e4943adf", "label": "house", "score": None}, + {"mask": "4537bdc07d47d84b3f8634b7ada37bd4", "label": "grass", "score": None}, + {"mask": "b7ac77dfae44a904b479a0926a2acaf7", "label": "tree", "score": None}, + {"mask": "e9bedd56bd40650fb263ce03eb621079", "label": "plant", "score": None}, + {"mask": "37a609f8c9c1b8db91fbff269f428b20", "label": "road, route", "score": None}, + {"mask": "0d8cdfd63bae8bf6e4344d460a2fa711", "label": "sky", "score": None}, + ], + )