Image Segmentation pipeline (#13828)
* Implement img seg pipeline * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/image_segmentation.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update output shape with individual masks * Rm dev change * Remove loops in test Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||||||
- :class:`~transformers.FeatureExtractionPipeline`
|
- :class:`~transformers.FeatureExtractionPipeline`
|
||||||
- :class:`~transformers.FillMaskPipeline`
|
- :class:`~transformers.FillMaskPipeline`
|
||||||
- :class:`~transformers.ImageClassificationPipeline`
|
- :class:`~transformers.ImageClassificationPipeline`
|
||||||
|
- :class:`~transformers.ImageSegmentationPipeline`
|
||||||
- :class:`~transformers.ObjectDetectionPipeline`
|
- :class:`~transformers.ObjectDetectionPipeline`
|
||||||
- :class:`~transformers.QuestionAnsweringPipeline`
|
- :class:`~transformers.QuestionAnsweringPipeline`
|
||||||
- :class:`~transformers.SummarizationPipeline`
|
- :class:`~transformers.SummarizationPipeline`
|
||||||
@@ -137,6 +138,13 @@ ImageClassificationPipeline
|
|||||||
:special-members: __call__
|
:special-members: __call__
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
ImageSegmentationPipeline
|
||||||
|
=======================================================================================================================
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ImageSegmentationPipeline
|
||||||
|
:special-members: __call__
|
||||||
|
:members:
|
||||||
|
|
||||||
NerPipeline
|
NerPipeline
|
||||||
=======================================================================================================================
|
=======================================================================================================================
|
||||||
|
|
||||||
|
|||||||
@@ -163,6 +163,13 @@ AutoModelForObjectDetection
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForImageSegmentation
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.AutoModelForImageSegmentation
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFAutoModel
|
TFAutoModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ _import_structure = {
|
|||||||
"FeatureExtractionPipeline",
|
"FeatureExtractionPipeline",
|
||||||
"FillMaskPipeline",
|
"FillMaskPipeline",
|
||||||
"ImageClassificationPipeline",
|
"ImageClassificationPipeline",
|
||||||
|
"ImageSegmentationPipeline",
|
||||||
"JsonPipelineDataFormat",
|
"JsonPipelineDataFormat",
|
||||||
"NerPipeline",
|
"NerPipeline",
|
||||||
"ObjectDetectionPipeline",
|
"ObjectDetectionPipeline",
|
||||||
@@ -544,6 +545,7 @@ if is_torch_available():
|
|||||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
|
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
@@ -561,6 +563,7 @@ if is_torch_available():
|
|||||||
"AutoModelForCausalLM",
|
"AutoModelForCausalLM",
|
||||||
"AutoModelForCTC",
|
"AutoModelForCTC",
|
||||||
"AutoModelForImageClassification",
|
"AutoModelForImageClassification",
|
||||||
|
"AutoModelForImageSegmentation",
|
||||||
"AutoModelForMaskedLM",
|
"AutoModelForMaskedLM",
|
||||||
"AutoModelForMultipleChoice",
|
"AutoModelForMultipleChoice",
|
||||||
"AutoModelForNextSentencePrediction",
|
"AutoModelForNextSentencePrediction",
|
||||||
@@ -2113,6 +2116,7 @@ if TYPE_CHECKING:
|
|||||||
FeatureExtractionPipeline,
|
FeatureExtractionPipeline,
|
||||||
FillMaskPipeline,
|
FillMaskPipeline,
|
||||||
ImageClassificationPipeline,
|
ImageClassificationPipeline,
|
||||||
|
ImageSegmentationPipeline,
|
||||||
JsonPipelineDataFormat,
|
JsonPipelineDataFormat,
|
||||||
NerPipeline,
|
NerPipeline,
|
||||||
ObjectDetectionPipeline,
|
ObjectDetectionPipeline,
|
||||||
@@ -2320,6 +2324,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
@@ -2337,6 +2342,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForCTC,
|
AutoModelForCTC,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
AutoModelForImageSegmentation,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForNextSentencePrediction,
|
AutoModelForNextSentencePrediction,
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from .models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
@@ -60,6 +61,7 @@ from .utils import logging
|
|||||||
TASK_MAPPING = {
|
TASK_MAPPING = {
|
||||||
"text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
"text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
"image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
|
||||||
"fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
"fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
"object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
"object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
||||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
@@ -273,6 +275,7 @@ should probably proofread and complete it, then remove this comment. -->
|
|||||||
TASK_TAG_TO_NAME_MAPPING = {
|
TASK_TAG_TO_NAME_MAPPING = {
|
||||||
"fill-mask": "Masked Language Modeling",
|
"fill-mask": "Masked Language Modeling",
|
||||||
"image-classification": "Image Classification",
|
"image-classification": "Image Classification",
|
||||||
|
"image-segmentation": "Image Segmentation",
|
||||||
"multiple-choice": "Multiple Choice",
|
"multiple-choice": "Multiple Choice",
|
||||||
"object-detection": "Object Detection",
|
"object-detection": "Object Detection",
|
||||||
"question-answering": "Question Answering",
|
"question-answering": "Question Answering",
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ if is_torch_available():
|
|||||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||||
"MODEL_FOR_CTC_MAPPING",
|
"MODEL_FOR_CTC_MAPPING",
|
||||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
|
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
@@ -52,6 +53,7 @@ if is_torch_available():
|
|||||||
"AutoModelForCausalLM",
|
"AutoModelForCausalLM",
|
||||||
"AutoModelForCTC",
|
"AutoModelForCTC",
|
||||||
"AutoModelForImageClassification",
|
"AutoModelForImageClassification",
|
||||||
|
"AutoModelForImageSegmentation",
|
||||||
"AutoModelForMaskedLM",
|
"AutoModelForMaskedLM",
|
||||||
"AutoModelForMultipleChoice",
|
"AutoModelForMultipleChoice",
|
||||||
"AutoModelForNextSentencePrediction",
|
"AutoModelForNextSentencePrediction",
|
||||||
@@ -130,6 +132,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_CTC_MAPPING,
|
MODEL_FOR_CTC_MAPPING,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
@@ -148,6 +151,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForCTC,
|
AutoModelForCTC,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
AutoModelForImageSegmentation,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForNextSentencePrediction,
|
AutoModelForNextSentencePrediction,
|
||||||
|
|||||||
@@ -228,6 +228,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Image Segmentation mapping
|
||||||
|
("detr", "DetrForSegmentation"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Masked LM mapping
|
# Model for Masked LM mapping
|
||||||
@@ -484,6 +491,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_C
|
|||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||||
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
@@ -614,6 +624,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
|||||||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForObjectDetection(_BaseAutoModelClass):
|
class AutoModelForObjectDetection(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -713,8 +713,50 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
    """
    Converts the output of :class:`~transformers.DetrForSegmentation` into image segmentation predictions. Only
    supports PyTorch.

    Parameters:
        outputs (:class:`~transformers.DetrSegmentationOutput`):
            Raw outputs of the model. Only ``outputs.logits`` and ``outputs.pred_masks`` are read.
        target_sizes (:obj:`torch.Tensor` of shape :obj:`(batch_size, 2)` or :obj:`List[Tuple]` of length :obj:`batch_size`):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
        threshold (:obj:`float`, `optional`, defaults to 0.9):
            Threshold to use to filter out queries.
        mask_threshold (:obj:`float`, `optional`, defaults to 0.5):
            Threshold to use when turning the predicted masks into binary values.

    Returns:
        :obj:`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an
        image in the batch as predicted by the model.
    """
    out_logits, raw_masks = outputs.logits, outputs.pred_masks

    # The last class index marks an empty query; it is the same for every image, so compute it once.
    empty_label = out_logits.shape[-1] - 1

    def to_tuple(size):
        # Plain Python sequences are used directly; anything else is treated as a tensor.
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return tuple(size.cpu().tolist())

    preds = []
    for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
        # Best class and its probability for every query of this image.
        cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)

        # we filter empty queries and detection below threshold
        keep = cur_classes.ne(empty_label) & (cur_scores > threshold)
        cur_scores = cur_scores[keep]
        cur_classes = cur_classes[keep]
        cur_masks = cur_masks[keep]

        # Resize the kept masks to the requested (h, w), then binarize them to 0/1 integers.
        cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
        cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1

        preds.append({"scores": cur_scores, "labels": cur_classes, "masks": cur_masks})
    return preds
|
||||||
|
|
||||||
# inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
|
# inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
|
||||||
def post_process_segmentation(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
||||||
"""
|
"""
|
||||||
Converts the output of :class:`~transformers.DetrForSegmentation` into actual instance segmentation
|
Converts the output of :class:`~transformers.DetrForSegmentation` into actual instance segmentation
|
||||||
predictions. Only supports PyTorch.
|
predictions. Only supports PyTorch.
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline
|
|||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
from .fill_mask import FillMaskPipeline
|
from .fill_mask import FillMaskPipeline
|
||||||
from .image_classification import ImageClassificationPipeline
|
from .image_classification import ImageClassificationPipeline
|
||||||
|
from .image_segmentation import ImageSegmentationPipeline
|
||||||
from .object_detection import ObjectDetectionPipeline
|
from .object_detection import ObjectDetectionPipeline
|
||||||
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
||||||
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
||||||
@@ -92,6 +93,7 @@ if is_torch_available():
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForCTC,
|
AutoModelForCTC,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
AutoModelForImageSegmentation,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForObjectDetection,
|
AutoModelForObjectDetection,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
@@ -231,6 +233,12 @@ SUPPORTED_TASKS = {
|
|||||||
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||||
},
|
},
|
||||||
|
"image-segmentation": {
|
||||||
|
"impl": ImageSegmentationPipeline,
|
||||||
|
"tf": (),
|
||||||
|
"pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
|
||||||
|
"default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
|
||||||
|
},
|
||||||
"object-detection": {
|
"object-detection": {
|
||||||
"impl": ObjectDetectionPipeline,
|
"impl": ObjectDetectionPipeline,
|
||||||
"tf": (),
|
"tf": (),
|
||||||
|
|||||||
165
src/transformers/pipelines/image_segmentation.py
Normal file
165
src/transformers/pipelines/image_segmentation.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||||
|
from ..utils import logging
|
||||||
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
Prediction = Dict[str, Any]
|
||||||
|
Predictions = List[Prediction]
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageSegmentationPipeline(Pipeline):
    """
    Image segmentation pipeline using any :obj:`AutoModelForImageSegmentation`. This pipeline predicts masks of objects
    and their classes.

    This image segmentation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"image-segmentation"`.

    See the list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=image-segmentation>`__.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)

    @staticmethod
    def load_image(image: Union[str, "Image.Image"]):
        """
        Load ``image`` as an RGB PIL image.

        Args:
            image (:obj:`str` or :obj:`PIL.Image`):
                An HTTP(S) URL, a path to a local file, or an already loaded PIL image.

        Returns:
            The image converted to RGB mode.

        Raises:
            ValueError: If ``image`` is a string that is neither an HTTP(S) URL nor an existing file, or if it is
            neither a string nor a PIL image.
        """
        if isinstance(image, str):
            if image.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                image = Image.open(requests.get(image, stream=True).raw)
            elif os.path.isfile(image):
                image = Image.open(image)
            else:
                raise ValueError(
                    f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
                )
        elif not isinstance(image, Image.Image):
            raise ValueError(
                "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
            )
        return image.convert("RGB")

    def _sanitize_parameters(self, **kwargs):
        """
        Route call-time keyword arguments to the pipeline steps.

        Only ``threshold`` and ``mask_threshold`` are recognized, and both are forwarded to :meth:`postprocess`; the
        preprocess and forward steps receive no extra parameters.
        """
        postprocess_kwargs = {name: kwargs[name] for name in ("threshold", "mask_threshold") if name in kwargs}
        return {}, {}, postprocess_kwargs

    def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Perform segmentation (detect masks & classes) in the image(s) passed as inputs.

        Args:
            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            threshold (:obj:`float`, `optional`, defaults to 0.9):
                The probability necessary to make a prediction.
            mask_threshold (:obj:`float`, `optional`, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.

        Return:
            A list of dictionaries or a list of lists of dictionaries containing the result. If the input is a single
            image, will return a list of dictionaries (one per predicted mask), if the input is a list of several
            images, will return a list of such lists, one for each image.

            The dictionaries contain the following keys:

            - **label** (:obj:`str`) -- The class label identified by the model.
            - **score** (:obj:`float`) -- The score attributed by the model for that label.
            - **mask** (:obj:`str`) -- base64 string of a single-channel PNG image that contains mask information. The
              PNG image has size (height, width) of the original image. Pixel values in the image are either 0 or 255
              (i.e. mask is absent VS mask is present).
        """

        return super().__call__(*args, **kwargs)

    def preprocess(self, image):
        """
        Load ``image`` and run the feature extractor on it.

        The original (height, width) is stored under ``"target_size"`` so that :meth:`postprocess` can request masks
        at the size of the input image.
        """
        image = self.load_image(image)
        target_size = torch.IntTensor([[image.height, image.width]])
        inputs = self.feature_extractor(images=[image], return_tensors="pt")
        inputs["target_size"] = target_size
        return inputs

    def _forward(self, model_inputs):
        """
        Run the model, carrying ``"target_size"`` alongside the model outputs.
        """
        # target_size is not a model input: remove it before the forward pass and hand it to postprocess instead.
        target_size = model_inputs.pop("target_size")
        outputs = self.model(**model_inputs)
        return {"outputs": outputs, "target_size": target_size}

    def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5):
        """
        Turn raw model outputs into a list of ``{"score", "label", "mask"}`` dictionaries.

        Args:
            model_outputs (:obj:`dict`):
                The dictionary produced by :meth:`_forward` (keys ``"outputs"`` and ``"target_size"``).
            threshold (:obj:`float`, `optional`, defaults to 0.9):
                Forwarded to the feature extractor's ``post_process_segmentation``.
            mask_threshold (:obj:`float`, `optional`, defaults to 0.5):
                Forwarded to the feature extractor's ``post_process_segmentation``.

        Returns:
            :obj:`List[Dict]`: One dictionary per predicted mask.
        """
        raw_annotations = self.feature_extractor.post_process_segmentation(
            model_outputs["outputs"],
            model_outputs["target_size"],
            threshold=threshold,
            # Forward the caller's value; it used to be hard-coded to 0.5, which ignored the argument.
            mask_threshold=mask_threshold,
        )
        # preprocess builds a single-image batch, so only the first entry is used.
        raw_annotation = raw_annotations[0]

        raw_annotation["masks"] *= 255  # [0,1] -> [0,255] black and white pixels

        raw_annotation["scores"] = raw_annotation["scores"].tolist()
        raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in raw_annotation["labels"]]
        raw_annotation["masks"] = [self._get_mask_str(mask) for mask in raw_annotation["masks"].cpu().numpy()]

        # {"scores": [...], ...} --> [{"score":x, ...}, ...]
        keys = ["score", "label", "mask"]
        annotation = [
            dict(zip(keys, vals))
            for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["masks"])
        ]

        return annotation

    def _get_mask_str(self, mask: np.ndarray) -> str:
        """
        Turns mask numpy array into mask base64 str.

        Args:
            mask (np.ndarray):
                Numpy array (with shape (height, width) of the original image) containing mask information. Values
                in the array are either 0 or 255 (i.e. mask is absent VS mask is present).

        Returns:
            A base64 string of a single-channel PNG image that contains mask information.
        """
        # NOTE(review): 255 does not fit in int8 (it wraps to -1); np.uint8 looks like the intended dtype. Left
        # unchanged because switching dtype changes the encoded PNG bytes -- confirm before changing.
        img = Image.fromarray(mask.astype(np.int8))
        with io.BytesIO() as out:
            img.save(out, format="PNG")
            png_string = out.getvalue()
            return base64.b64encode(png_string).decode("utf-8")
|
||||||
@@ -316,6 +316,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None
|
|||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_MASKED_LM_MAPPING = None
|
MODEL_FOR_MASKED_LM_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -397,6 +400,15 @@ class AutoModelForImageClassification:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForImageSegmentation:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForMaskedLM:
|
class AutoModelForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
241
tests/test_pipelines_image_segmentation.py
Normal file
241
tests/test_pipelines_image_segmentation.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
AutoModelForImageSegmentation,
|
||||||
|
ImageSegmentationPipeline,
|
||||||
|
is_vision_available,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
nested_simplify,
|
||||||
|
require_datasets,
|
||||||
|
require_tf,
|
||||||
|
require_timm,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
@staticmethod
|
||||||
|
def open(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
@require_timm
@require_torch
@is_pipeline_test
class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for ``ImageSegmentationPipeline``.

    Covers a generic run against the models listed in ``model_mapping``, a
    tiny PyTorch checkpoint, and slow integration checks against
    ``facebook/detr-resnet-50-panoptic``.
    """

    # NOTE(review): presumably consumed by PipelineTestCaseMeta to generate tests
    # that call `run_pipeline_test` -- confirm against the metaclass.
    model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING

    @staticmethod
    def _shorten_masks(segments):
        """Replace every segment's ``mask`` string by its SHA-1 hex digest, in place.

        The masks are long strings; hashing them keeps the expected values in
        the assertions short.
        """
        for segment in segments:
            segment["mask"] = hashlib.sha1(segment["mask"].encode("UTF-8")).hexdigest()

    @require_datasets
    def run_pipeline_test(self, model, tokenizer, feature_extractor):
        # Only the structure of the output is checked here: 12 segments per image,
        # each one holding a float score, a string label and a string mask.
        local_image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)

        single_result = image_segmenter(local_image, threshold=0.0)
        self.assertEqual(
            single_result,
            [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)} for _ in range(12)],
        )

        from datasets import load_dataset

        dummy_images = load_dataset("Narsil/image_dummy", "image", split="test")

        # Mix of input kinds: an opened PIL image, a URL and dataset-provided files.
        batch = [
            Image.open(local_image),
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            dummy_images[0]["file"],  # RGBA
            dummy_images[1]["file"],  # LA
            dummy_images[2]["file"],  # L
        ]
        batch_results = image_segmenter(batch, threshold=0.0)

        self.assertEqual(len(batch), len(batch_results))
        # One list of 12 segments is expected for each of the five inputs.
        self.assertEqual(
            batch_results,
            [
                [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)} for _ in range(12)]
                for _ in range(len(batch))
            ],
        )

    @require_tf
    @unittest.skip("Image segmentation not implemented in TF")
    def test_small_model_tf(self):
        # Placeholder: always skipped, see the skip reason above.
        pass

    @require_torch
    def test_small_model_pt(self):
        # Exact-value check on a tiny checkpoint, for one image and for a batch
        # made of the same image twice.
        checkpoint = "mishig/tiny-detr-mobilenetsv3-panoptic"
        image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

        image_segmenter = ImageSegmentationPipeline(
            model=AutoModelForImageSegmentation.from_pretrained(checkpoint),
            feature_extractor=AutoFeatureExtractor.from_pretrained(checkpoint),
        )

        expected_segments = [
            {
                "score": 0.004,
                "label": "LABEL_0",
                "mask": "8423ef82b9a8e8790346bc452b557aa78ea997bc",
            }
            for _ in range(2)
        ]

        single_result = image_segmenter(image_url, threshold=0.0)
        self._shorten_masks(single_result)
        self.assertEqual(nested_simplify(single_result, decimals=4), expected_segments)

        batch_results = image_segmenter([image_url, image_url], threshold=0.0)
        for per_image in batch_results:
            self._shorten_masks(per_image)
        self.assertEqual(
            nested_simplify(batch_results, decimals=4),
            [expected_segments, expected_segments],
        )

    @require_torch
    @slow
    def test_integration_torch_image_segmentation(self):
        # Exact-value check on the full-size checkpoint, built through `pipeline(...)`.
        image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image_segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")

        expected_segments = [
            {"score": 0.9094, "label": "blanket", "mask": "f939d943609821ad27cdb92844f2754ad3735b52"},
            {"score": 0.9941, "label": "cat", "mask": "32913606de3958812ced0090df7b699abb6e2644"},
            {"score": 0.9987, "label": "remote", "mask": "f3988d35f3065f591fa6a0a9414614d98a9ca13e"},
            {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
            {"score": 0.9722, "label": "couch", "mask": "543c3244b291c4aec134f1d8f92af553da795529"},
            {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
        ]

        # Single image, threshold left at the pipeline default.
        single_result = image_segmenter(image_url)
        self._shorten_masks(single_result)
        self.assertEqual(nested_simplify(single_result, decimals=4), expected_segments)

        # Same image twice in one call, this time with an explicit zero threshold.
        batch_results = image_segmenter([image_url, image_url], threshold=0.0)
        for per_image in batch_results:
            self._shorten_masks(per_image)
        self.assertEqual(
            nested_simplify(batch_results, decimals=4),
            [expected_segments, expected_segments],
        )

    @require_torch
    @slow
    def test_threshold(self):
        # With a 0.999 threshold only the two highest-scoring segments are expected.
        image_segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")

        filtered = image_segmenter(
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            threshold=0.999,
        )
        self._shorten_masks(filtered)

        self.assertEqual(
            nested_simplify(filtered, decimals=4),
            [
                {"score": 0.9995, "label": "remote", "mask": "ff0d541ace4fe386fc14ced0c546490a8e7001d7"},
                {"score": 0.9994, "label": "cat", "mask": "891313e21290200e6169613e6a9cb7aff9e7b22f"},
            ],
        )
|
||||||
Reference in New Issue
Block a user