From 026866df92afe40cdf928839864111015a62d3b5 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Fri, 8 Oct 2021 09:59:53 +0200 Subject: [PATCH] Image Segmentation pipeline (#13828) * Implement img seg pipeline * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update output shape with individual masks * Rm dev change * Remove loops in test Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- docs/source/main_classes/pipelines.rst | 8 + docs/source/model_doc/auto.rst | 7 + src/transformers/__init__.py | 6 + src/transformers/modelcard.py | 3 + src/transformers/models/auto/__init__.py | 4 + src/transformers/models/auto/modeling_auto.py | 17 ++ .../models/detr/feature_extraction_detr.py | 44 +++- src/transformers/pipelines/__init__.py | 8 + .../pipelines/image_segmentation.py | 165 ++++++++++++ src/transformers/utils/dummy_pt_objects.py | 12 + tests/test_pipelines_image_segmentation.py | 241 ++++++++++++++++++ 11 files changed, 514 insertions(+), 1 deletion(-) create mode 100644 src/transformers/pipelines/image_segmentation.py create mode 100644 tests/test_pipelines_image_segmentation.py diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst index 251f0fc53f..1bc60d340f 100644 --- a/docs/source/main_classes/pipelines.rst +++ b/docs/source/main_classes/pipelines.rst @@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about: - :class:`~transformers.FeatureExtractionPipeline` - :class:`~transformers.FillMaskPipeline` - :class:`~transformers.ImageClassificationPipeline` + - :class:`~transformers.ImageSegmentationPipeline` - :class:`~transformers.ObjectDetectionPipeline` - :class:`~transformers.QuestionAnsweringPipeline` - :class:`~transformers.SummarizationPipeline` @@ -137,6 +138,13 @@ ImageClassificationPipeline :special-members: __call__ :members: +ImageSegmentationPipeline +======================================================================================================================= + +.. autoclass:: transformers.ImageSegmentationPipeline + :special-members: __call__ + :members: + NerPipeline ======================================================================================================================= diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index af64f83d5f..0cc3b21751 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -163,6 +163,13 @@ AutoModelForObjectDetection :members: +AutoModelForImageSegmentation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.AutoModelForImageSegmentation + :members: + + TFAutoModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a3ab3765d1..edb45fbc59 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -294,6 +294,7 @@ _import_structure = { "FeatureExtractionPipeline", "FillMaskPipeline", "ImageClassificationPipeline", + "ImageSegmentationPipeline", "JsonPipelineDataFormat", "NerPipeline", "ObjectDetectionPipeline", @@ -544,6 +545,7 @@ if is_torch_available(): "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -561,6 +563,7 @@ if is_torch_available(): "AutoModelForCausalLM", "AutoModelForCTC", "AutoModelForImageClassification", + "AutoModelForImageSegmentation", "AutoModelForMaskedLM", "AutoModelForMultipleChoice", "AutoModelForNextSentencePrediction", @@ -2113,6 +2116,7 @@ if TYPE_CHECKING: FeatureExtractionPipeline, FillMaskPipeline, ImageClassificationPipeline, + ImageSegmentationPipeline, JsonPipelineDataFormat, NerPipeline, ObjectDetectionPipeline, @@ -2320,6 +2324,7 @@ if TYPE_CHECKING: MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -2337,6 +2342,7 @@ if TYPE_CHECKING: AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, + AutoModelForImageSegmentation, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForNextSentencePrediction, diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 1c28c95f3a..0dfcc17a9d 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -45,6 +45,7 @@ from .models.auto.modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, @@ -60,6 +61,7 @@ from .utils import logging TASK_MAPPING = { "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES, "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, @@ -273,6 +275,7 @@ should probably proofread and complete it, then remove this comment. --> TASK_TAG_TO_NAME_MAPPING = { "fill-mask": "Masked Language Modeling", "image-classification": "Image Classification", + "image-segmentation": "Image Segmentation", "multiple-choice": "Multiple Choice", "object-detection": "Object Detection", "question-answering": "Question Answering", diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 1666f483a7..98133afee2 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -34,6 +34,7 @@ if is_torch_available(): "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -52,6 +53,7 @@ if is_torch_available(): "AutoModelForCausalLM", "AutoModelForCTC", "AutoModelForImageClassification", + "AutoModelForImageSegmentation", "AutoModelForMaskedLM", "AutoModelForMultipleChoice", "AutoModelForNextSentencePrediction", @@ -130,6 +132,7 @@ if TYPE_CHECKING: MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CTC_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -148,6 +151,7 @@ if TYPE_CHECKING: AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, + AutoModelForImageSegmentation, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForNextSentencePrediction, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e246728d68..098305672a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -228,6 +228,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image Segmentation mapping + ("detr", "DetrForSegmentation"), + ] +) + MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping @@ -484,6 +491,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_C MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) +MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES +) MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( @@ -614,6 +624,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass): AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") +class AutoModelForImageSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING + + +AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") + + class AutoModelForObjectDetection(_BaseAutoModelClass): _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py index 09962b5057..cbdd375289 100644 --- a/src/transformers/models/detr/feature_extraction_detr.py +++ b/src/transformers/models/detr/feature_extraction_detr.py @@ -713,8 +713,50 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): return results + def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5): + """ + Converts the output of :class:`~transformers.DetrForSegmentation` into image segmentation predictions. Only + supports PyTorch. + + Parameters: + outputs (:class:`~transformers.DetrSegmentationOutput`): + Raw outputs of the model. + target_sizes (:obj:`torch.Tensor` of shape :obj:`(batch_size, 2)` or :obj:`List[Tuple]` of length :obj:`batch_size`): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. + threshold (:obj:`float`, `optional`, defaults to 0.9): + Threshold to use to filter out queries. + mask_threshold (:obj:`float`, `optional`, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + + Returns: + :obj:`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an + image in the batch as predicted by the model. + """ + out_logits, raw_masks = outputs.logits, outputs.pred_masks + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1 + + predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks} + preds.append(predictions) + return preds + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218 - def post_process_segmentation(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): + def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): """ Converts the output of :class:`~transformers.DetrForSegmentation` into actual instance segmentation predictions. Only supports PyTorch. diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 3d46a372dc..19bae7ff0d 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline from .feature_extraction import FeatureExtractionPipeline from .fill_mask import FillMaskPipeline from .image_classification import ImageClassificationPipeline +from .image_segmentation import ImageSegmentationPipeline from .object_detection import ObjectDetectionPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline @@ -92,6 +93,7 @@ if is_torch_available(): AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, + AutoModelForImageSegmentation, AutoModelForMaskedLM, AutoModelForObjectDetection, AutoModelForQuestionAnswering, @@ -231,6 +233,12 @@ SUPPORTED_TASKS = { "pt": (AutoModelForImageClassification,) if is_torch_available() else (), "default": {"model": {"pt": "google/vit-base-patch16-224"}}, }, + "image-segmentation": { + "impl": ImageSegmentationPipeline, + "tf": (), + "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (), + "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}}, + }, "object-detection": { "impl": ObjectDetectionPipeline, "tf": (), diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py new file mode 100644 index 0000000000..9153a403e6 --- /dev/null +++ b/src/transformers/pipelines/image_segmentation.py @@ -0,0 +1,165 @@ +import base64 +import io +import os +from typing import Any, Dict, List, Union + +import numpy as np + +import requests + +from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends +from ..utils import logging +from .base import PIPELINE_INIT_ARGS, Pipeline + + +if is_vision_available(): + from PIL import Image + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING + +logger = logging.get_logger(__name__) + + +Prediction = Dict[str, Any] +Predictions = List[Prediction] + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class ImageSegmentationPipeline(Pipeline): + """ + Image segmentation pipeline using any :obj:`AutoModelForImageSegmentation`. This pipeline predicts masks of objects + and their classes. + + This image segmntation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following + task identifier: :obj:`"image-segmentation"`. + + See the list of available models on `huggingface.co/models + `__. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.framework == "tf": + raise ValueError(f"The {self.__class__} is only available in PyTorch.") + + requires_backends(self, "vision") + self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING) + + @staticmethod + def load_image(image: Union[str, "Image.Image"]): + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + # We need to actually check for a real protocol, otherwise it's impossible to use a local file + # like http_huggingface_co.png + image = Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, Image.Image): + pass + else: + raise ValueError( + "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image." + ) + image = image.convert("RGB") + return image + + def _sanitize_parameters(self, **kwargs): + postprocess_kwargs = {} + if "threshold" in kwargs: + postprocess_kwargs["threshold"] = kwargs["threshold"] + if "mask_threshold" in kwargs: + postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"] + return {}, {}, postprocess_kwargs + + def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]: + """ + Perform segmentation (detect masks & classes) in the image(s) passed as inputs. + + Args: + images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`): + The pipeline handles three types of images: + + - A string containing an HTTP(S) link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the + same format: all as HTTP(S) links, all as local paths, or all as PIL images. + threshold (:obj:`float`, `optional`, defaults to 0.9): + The probability necessary to make a prediction. + mask_threshold (:obj:`float`, `optional`, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + + Return: + A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a + dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to + each image. + + The dictionaries contain the following keys: + + - **label** (:obj:`str`) -- The class label identified by the model. + - **score** (:obj:`float`) -- The score attributed by the model for that label. + - **mask** (:obj:`str`) -- base64 string of a single-channel PNG image that contain masks information. The + PNG image has size (heigth, width) of the original image. Pixel values in the image are either 0 or 255 + (i.e. mask is absent VS mask is present). + """ + + return super().__call__(*args, **kwargs) + + def preprocess(self, image): + image = self.load_image(image) + target_size = torch.IntTensor([[image.height, image.width]]) + inputs = self.feature_extractor(images=[image], return_tensors="pt") + inputs["target_size"] = target_size + return inputs + + def _forward(self, model_inputs): + target_size = model_inputs.pop("target_size") + outputs = self.model(**model_inputs) + model_outputs = {"outputs": outputs, "target_size": target_size} + return model_outputs + + def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5): + raw_annotations = self.feature_extractor.post_process_segmentation( + model_outputs["outputs"], model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 + ) + raw_annotation = raw_annotations[0] + + raw_annotation["masks"] *= 255 # [0,1] -> [0,255] black and white pixels + + raw_annotation["scores"] = raw_annotation["scores"].tolist() + raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in raw_annotation["labels"]] + raw_annotation["masks"] = [self._get_mask_str(mask) for mask in raw_annotation["masks"].cpu().numpy()] + + # {"scores": [...], ...} --> [{"score":x, ...}, ...] + keys = ["score", "label", "mask"] + annotation = [ + dict(zip(keys, vals)) + for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["masks"]) + ] + + return annotation + + def _get_mask_str(self, mask: np.array) -> str: + """ + Turns mask numpy array into mask base64 str. + + Args: + mask (np.array): Numpy array (with shape (heigth, width) of the original image) containing masks information. Values in the array are either 0 or 255 (i.e. mask is absent VS mask is present). + + Returns: + A base64 string of a single-channel PNG image that contain masks information. + """ + img = Image.fromarray(mask.astype(np.int8)) + with io.BytesIO() as out: + img.save(out, format="PNG") + png_string = out.getvalue() + return base64.b64encode(png_string).decode("utf-8") diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c91cb17c0d..f43e4ad4c7 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -316,6 +316,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None +MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None + + MODEL_FOR_MASKED_LM_MAPPING = None @@ -397,6 +400,15 @@ class AutoModelForImageClassification: requires_backends(cls, ["torch"]) +class AutoModelForImageSegmentation: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoModelForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_pipelines_image_segmentation.py b/tests/test_pipelines_image_segmentation.py new file mode 100644 index 0000000000..9ccedda8a5 --- /dev/null +++ b/tests/test_pipelines_image_segmentation.py @@ -0,0 +1,241 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import unittest + +from transformers import ( + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, + AutoFeatureExtractor, + AutoModelForImageSegmentation, + ImageSegmentationPipeline, + is_vision_available, + pipeline, +) +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_datasets, + require_tf, + require_timm, + require_torch, + require_vision, + slow, +) + +from .test_pipelines_common import ANY, PipelineTestCaseMeta + + +if is_vision_available(): + from PIL import Image +else: + + class Image: + @staticmethod + def open(*args, **kwargs): + pass + + +@require_vision +@require_timm +@require_torch +@is_pipeline_test +class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): + model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING + + @require_datasets + def run_pipeline_test(self, model, tokenizer, feature_extractor): + image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor) + outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) + self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12) + + import datasets + + dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test") + + batch = [ + Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"), + "http://images.cocodataset.org/val2017/000000039769.jpg", + # RGBA + dataset[0]["file"], + # LA + dataset[1]["file"], + # L + dataset[2]["file"], + ] + outputs = image_segmenter(batch, threshold=0.0) + + self.assertEqual(len(batch), len(outputs)) + self.assertEqual( + outputs, + [ + [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12, + [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12, + [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12, + [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12, + [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12, + ], + ) + + @require_tf + @unittest.skip("Image segmentation not implemented in TF") + def test_small_model_tf(self): + pass + + @require_torch + def test_small_model_pt(self): + model_id = "mishig/tiny-detr-mobilenetsv3-panoptic" + + model = AutoModelForImageSegmentation.from_pretrained(model_id) + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor) + + outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.0) + for o in outputs: + # shortening by hashing + o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest() + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + ], + ) + + outputs = image_segmenter( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ], + threshold=0.0, + ) + for output in outputs: + for o in output: + o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest() + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + ], + [ + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + { + "score": 0.004, + "label": "LABEL_0", + "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc", + }, + ], + ], + ) + + @require_torch + @slow + def test_integration_torch_image_segmentation(self): + model_id = "facebook/detr-resnet-50-panoptic" + + image_segmenter = pipeline("image-segmentation", model=model_id) + + outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg") + for o in outputs: + o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest() + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"}, + {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"}, + {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"}, + {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"}, + {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"}, + {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"}, + ], + ) + + outputs = image_segmenter( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ], + threshold=0.0, + ) + for output in outputs: + for o in output: + o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest() + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"}, + {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"}, + {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"}, + {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"}, + {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"}, + {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"}, + ], + [ + {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"}, + {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"}, + {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"}, + {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"}, + {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"}, + {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"}, + ], + ], + ) + + @require_torch + @slow + def test_threshold(self): + threshold = 0.999 + model_id = "facebook/detr-resnet-50-panoptic" + + image_segmenter = pipeline("image-segmentation", model=model_id) + + outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=threshold) + + for o in outputs: + o["mask"] = hashlib.sha1(o["mask"].encode("UTF-8")).hexdigest() + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"}, + {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"}, + ], + )