Enabling MaskFormer in pipelines (#15917)
* Enabling MaskFormer in ppipelines No AutoModel though :( * Ooops local file.
This commit is contained in:
@@ -110,7 +110,18 @@ class ImageSegmentationPipeline(Pipeline):
|
|||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
|
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
|
||||||
if hasattr(self.feature_extractor, "post_process_segmentation"):
|
if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
|
||||||
|
outputs = self.feature_extractor.post_process_panoptic_segmentation(
|
||||||
|
model_outputs, is_thing_map=self.model.config.id2label
|
||||||
|
)[0]
|
||||||
|
annotation = []
|
||||||
|
segmentation = outputs["segmentation"]
|
||||||
|
for segment in outputs["segments"]:
|
||||||
|
mask = (segmentation == segment["id"]) * 255
|
||||||
|
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
|
||||||
|
label = self.model.config.id2label[segment["category_id"]]
|
||||||
|
annotation.append({"mask": mask, "label": label, "score": None})
|
||||||
|
elif hasattr(self.feature_extractor, "post_process_segmentation"):
|
||||||
# Panoptic
|
# Panoptic
|
||||||
raw_annotations = self.feature_extractor.post_process_segmentation(
|
raw_annotations = self.feature_extractor.post_process_segmentation(
|
||||||
model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
|
model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import hashlib
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
@@ -308,3 +309,35 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
|||||||
{"score": 0.9994, "label": "cat", "mask": "88b37bd2202c750cc9dd191518050a9b0ca5228c"},
|
{"score": 0.9994, "label": "cat", "mask": "88b37bd2202c750cc9dd191518050a9b0ca5228c"},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_maskformer(self):
|
||||||
|
threshold = 0.999
|
||||||
|
model_id = "facebook/maskformer-swin-base-ade"
|
||||||
|
|
||||||
|
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
|
||||||
|
|
||||||
|
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
|
||||||
|
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id)
|
||||||
|
|
||||||
|
image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
outputs = image_segmenter(image[0]["file"], threshold=threshold)
|
||||||
|
|
||||||
|
for o in outputs:
|
||||||
|
o["mask"] = hashimage(o["mask"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{"mask": "20d1b9480d1dc1501dbdcfdff483e370", "label": "wall", "score": None},
|
||||||
|
{"mask": "0f902fbc66a0ff711ea455b0e4943adf", "label": "house", "score": None},
|
||||||
|
{"mask": "4537bdc07d47d84b3f8634b7ada37bd4", "label": "grass", "score": None},
|
||||||
|
{"mask": "b7ac77dfae44a904b479a0926a2acaf7", "label": "tree", "score": None},
|
||||||
|
{"mask": "e9bedd56bd40650fb263ce03eb621079", "label": "plant", "score": None},
|
||||||
|
{"mask": "37a609f8c9c1b8db91fbff269f428b20", "label": "road, route", "score": None},
|
||||||
|
{"mask": "0d8cdfd63bae8bf6e4344d460a2fa711", "label": "sky", "score": None},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user