Object detection pipeline (#12886)
* Implement object-detection pipeline * Define threshold const * Add `threshold` argument * Refactor * Uncomment test inputs * `rm Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Fix typo Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Fix typo Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Chore better doc Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Rm unnecessary lines Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Chore better naming Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Fix typo * Add `detr-tiny` for tests * Add `ObjectDetectionPipeline` to `trnsfrmrs/init` * Implement new bbox format * Update detr post_process * Update `load_img` method obj det pipeline * make style * Implement new testing format for obj det pipeln * Add guard pytorch specific code in pipeline * Add doc * Make pipeline_obj_tet tests deterministic * Revert some changes to `post_process` COCO api * Chore * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update 
src/transformers/pipelines/object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Rm timm requirement * make fixup * Add timm requirement to test * Make fixup * Guard torch.Tensor * Chore * Delete unnecessary comment Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||||||
- :class:`~transformers.FeatureExtractionPipeline`
|
- :class:`~transformers.FeatureExtractionPipeline`
|
||||||
- :class:`~transformers.FillMaskPipeline`
|
- :class:`~transformers.FillMaskPipeline`
|
||||||
- :class:`~transformers.ImageClassificationPipeline`
|
- :class:`~transformers.ImageClassificationPipeline`
|
||||||
|
- :class:`~transformers.ObjectDetectionPipeline`
|
||||||
- :class:`~transformers.QuestionAnsweringPipeline`
|
- :class:`~transformers.QuestionAnsweringPipeline`
|
||||||
- :class:`~transformers.SummarizationPipeline`
|
- :class:`~transformers.SummarizationPipeline`
|
||||||
- :class:`~transformers.TableQuestionAnsweringPipeline`
|
- :class:`~transformers.TableQuestionAnsweringPipeline`
|
||||||
@@ -102,6 +103,13 @@ NerPipeline
|
|||||||
|
|
||||||
See :class:`~transformers.TokenClassificationPipeline` for all details.
|
See :class:`~transformers.TokenClassificationPipeline` for all details.
|
||||||
|
|
||||||
|
ObjectDetectionPipeline
|
||||||
|
=======================================================================================================================
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ObjectDetectionPipeline
|
||||||
|
:special-members: __call__
|
||||||
|
:members:
|
||||||
|
|
||||||
QuestionAnsweringPipeline
|
QuestionAnsweringPipeline
|
||||||
=======================================================================================================================
|
=======================================================================================================================
|
||||||
|
|
||||||
|
|||||||
@@ -142,6 +142,13 @@ AutoModelForAudioClassification
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForObjectDetection
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.AutoModelForObjectDetection
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFAutoModel
|
TFAutoModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ _import_structure = {
|
|||||||
"ImageClassificationPipeline",
|
"ImageClassificationPipeline",
|
||||||
"JsonPipelineDataFormat",
|
"JsonPipelineDataFormat",
|
||||||
"NerPipeline",
|
"NerPipeline",
|
||||||
|
"ObjectDetectionPipeline",
|
||||||
"PipedPipelineDataFormat",
|
"PipedPipelineDataFormat",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
"PipelineDataFormat",
|
"PipelineDataFormat",
|
||||||
@@ -558,6 +559,7 @@ if is_torch_available():
|
|||||||
"AutoModelForMaskedLM",
|
"AutoModelForMaskedLM",
|
||||||
"AutoModelForMultipleChoice",
|
"AutoModelForMultipleChoice",
|
||||||
"AutoModelForNextSentencePrediction",
|
"AutoModelForNextSentencePrediction",
|
||||||
|
"AutoModelForObjectDetection",
|
||||||
"AutoModelForPreTraining",
|
"AutoModelForPreTraining",
|
||||||
"AutoModelForQuestionAnswering",
|
"AutoModelForQuestionAnswering",
|
||||||
"AutoModelForSeq2SeqLM",
|
"AutoModelForSeq2SeqLM",
|
||||||
@@ -2074,6 +2076,7 @@ if TYPE_CHECKING:
|
|||||||
ImageClassificationPipeline,
|
ImageClassificationPipeline,
|
||||||
JsonPipelineDataFormat,
|
JsonPipelineDataFormat,
|
||||||
NerPipeline,
|
NerPipeline,
|
||||||
|
ObjectDetectionPipeline,
|
||||||
PipedPipelineDataFormat,
|
PipedPipelineDataFormat,
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
@@ -2295,6 +2298,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForNextSentencePrediction,
|
AutoModelForNextSentencePrediction,
|
||||||
|
AutoModelForObjectDetection,
|
||||||
AutoModelForPreTraining,
|
AutoModelForPreTraining,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ if is_torch_available():
|
|||||||
"AutoModelForMaskedLM",
|
"AutoModelForMaskedLM",
|
||||||
"AutoModelForMultipleChoice",
|
"AutoModelForMultipleChoice",
|
||||||
"AutoModelForNextSentencePrediction",
|
"AutoModelForNextSentencePrediction",
|
||||||
|
"AutoModelForObjectDetection",
|
||||||
"AutoModelForPreTraining",
|
"AutoModelForPreTraining",
|
||||||
"AutoModelForQuestionAnswering",
|
"AutoModelForQuestionAnswering",
|
||||||
"AutoModelForSeq2SeqLM",
|
"AutoModelForSeq2SeqLM",
|
||||||
@@ -143,6 +144,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForNextSentencePrediction,
|
AutoModelForNextSentencePrediction,
|
||||||
|
AutoModelForObjectDetection,
|
||||||
AutoModelForPreTraining,
|
AutoModelForPreTraining,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from .configuration_auto import (
|
|||||||
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
("beit", "BeitFeatureExtractor"),
|
("beit", "BeitFeatureExtractor"),
|
||||||
|
("detr", "DetrFeatureExtractor"),
|
||||||
("deit", "DeiTFeatureExtractor"),
|
("deit", "DeiTFeatureExtractor"),
|
||||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||||
("speech_to_text", "Speech2TextFeatureExtractor"),
|
("speech_to_text", "Speech2TextFeatureExtractor"),
|
||||||
|
|||||||
@@ -588,6 +588,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
|||||||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForObjectDetection(_BaseAutoModelClass):
    # Auto class for object detection heads; it only selects which mapping to use.
    # NOTE(review): the mapping is defined elsewhere in this module - presumably config class -> model class.
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class, passing "object detection" as the
# head description.
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForAudioClassification(_BaseAutoModelClass):
|
class AutoModelForAudioClassification(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline
|
|||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
from .fill_mask import FillMaskPipeline
|
from .fill_mask import FillMaskPipeline
|
||||||
from .image_classification import ImageClassificationPipeline
|
from .image_classification import ImageClassificationPipeline
|
||||||
|
from .object_detection import ObjectDetectionPipeline
|
||||||
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
||||||
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
||||||
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
|
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
|
||||||
@@ -91,6 +92,7 @@ if is_torch_available():
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
|
AutoModelForObjectDetection,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
@@ -229,6 +231,12 @@ SUPPORTED_TASKS = {
|
|||||||
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
"pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
||||||
},
|
},
|
||||||
|
"object-detection": {
|
||||||
|
"impl": ObjectDetectionPipeline,
|
||||||
|
"tf": (),
|
||||||
|
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
|
||||||
|
"default": {"model": {"pt": "facebook/detr-resnet-50"}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
176
src/transformers/pipelines/object_detection.py
Normal file
176
src/transformers/pipelines/object_detection.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||||
|
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||||
|
from ..utils import logging
|
||||||
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
Prediction = Dict[str, Any]
|
||||||
|
Predictions = List[Prediction]
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ObjectDetectionPipeline(Pipeline):
    """
    Object detection pipeline using any :obj:`AutoModelForObjectDetection`. This pipeline predicts bounding boxes of
    objects and their classes.

    This object detection pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"object-detection"`.

    See the list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=object-detection>`__.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: PreTrainedFeatureExtractor,
        framework: Optional[str] = None,
        **kwargs
    ):
        """
        Build the pipeline and validate that it can run.

        Args:
            model (:obj:`PreTrainedModel`):
                The object detection model used for inference.
            feature_extractor (:obj:`PreTrainedFeatureExtractor`):
                The feature extractor used both to prepare the images and to post-process the model outputs.
            framework (:obj:`str`, `optional`):
                Forwarded to the base :obj:`Pipeline` constructor.

        Raises:
            :obj:`ValueError`: If the resolved framework is TensorFlow.
        """
        super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")

        self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)

        self.feature_extractor = feature_extractor

    @staticmethod
    def load_image(image: Union[str, "Image.Image"]) -> "Image.Image":
        """
        Load a single image and convert it to RGB.

        Args:
            image (:obj:`str` or :obj:`PIL.Image.Image`):
                An HTTP(S) URL, a path to an existing local file, or an already loaded PIL image.

        Returns:
            :obj:`PIL.Image.Image`: The image converted to the ``"RGB"`` mode.

        Raises:
            :obj:`ValueError`: If a string is neither an HTTP(S) URL nor an existing file, or if :obj:`image` is
            neither a string nor a PIL image.
        """
        if isinstance(image, str):
            if image.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                # NOTE(review): the request has no timeout and the HTTP status is not checked - confirm whether
                # that is acceptable for callers.
                image = Image.open(requests.get(image, stream=True).raw)
            elif os.path.isfile(image):
                image = Image.open(image)
            else:
                raise ValueError(
                    f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
                )
        elif not isinstance(image, Image.Image):
            raise ValueError(
                "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
            )

        return image.convert("RGB")

    def __call__(
        self,
        images: Union[str, List[str], "Image.Image", List["Image.Image"]],
        threshold: Optional[float] = 0.9,
    ) -> Union[Predictions, List[Predictions]]:
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Only a :obj:`list` is treated as
                a batch. Every element of a batch is loaded independently, so a batch may mix the three formats.
            threshold (:obj:`float`, `optional`, defaults to 0.9):
                Detections are kept only when their score is strictly greater than this value.

        Return:
            A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
            image, will return a list of dictionaries, if the input is a list of several images, will return a list of
            list of dictionaries corresponding to each image.

            The dictionaries contain the following keys:

            - **label** (:obj:`str`) -- The class label identified by the model.
            - **score** (:obj:`float`) -- The score attributed by the model for that label.
            - **box** (:obj:`Dict[str, int]`) -- The bounding box of detected object in image's original size, with
              the keys ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

        Raises:
            :obj:`ValueError`: If the pipeline framework is not PyTorch, or if an image cannot be loaded.
        """
        # Fail before loading images or running the model: everything below relies on torch.
        if self.framework != "pt":
            raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")

        is_batched = isinstance(images, list)
        if not is_batched:
            images = [images]

        images = [self.load_image(image) for image in images]

        with torch.no_grad():
            inputs = self.feature_extractor(images=images, return_tensors="pt")
            outputs = self.model(**inputs)

            # One (height, width) row per image, so boxes are reported in each image's original size.
            target_sizes = torch.IntTensor([[im.height, im.width] for im in images])
            raw_annotations = self.feature_extractor.post_process(outputs, target_sizes)

        id2label = self.model.config.id2label
        annotations = []
        for raw_annotation in raw_annotations:
            keep = raw_annotation["scores"] > threshold
            scores = raw_annotation["scores"][keep].tolist()
            labels = raw_annotation["labels"][keep].tolist()
            boxes = raw_annotation["boxes"][keep]

            # {"scores": [...], ...} --> [{"score":x, ...}, ...]
            annotations.append(
                [
                    {"score": score, "label": id2label[label], "box": self._get_bounding_box(box)}
                    for score, label, box in zip(scores, labels, boxes)
                ]
            )

        return annotations if is_batched else annotations[0]

    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns a tensor [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }

        Args:
            box (torch.Tensor): Tensor containing the four coordinates in corners format, ordered as
                ``xmin, ymin, xmax, ymax``.

        Returns:
            bbox (Dict[str, int]): Dict containing the coordinates in corners format, cast to integers with
            ``Tensor.int()``.

        Raises:
            ValueError: If the pipeline framework is not PyTorch.
        """
        if self.framework != "pt":
            raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")

        xmin, ymin, xmax, ymax = box.int().tolist()
        return {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}
|
||||||
@@ -415,6 +415,15 @@ class AutoModelForNextSentencePrediction:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForObjectDetection:
    # Placeholder sharing the real class's name: both construction and `from_pretrained` only call
    # `requires_backends` with the "torch" backend.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForPreTraining:
|
class AutoModelForPreTraining:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
253
tests/test_pipelines_object_detection.py
Normal file
253
tests/test_pipelines_object_detection.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
MODEL_FOR_OBJECT_DETECTION_MAPPING,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
AutoModelForObjectDetection,
|
||||||
|
ObjectDetectionPipeline,
|
||||||
|
is_vision_available,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
nested_simplify,
|
||||||
|
require_datasets,
|
||||||
|
require_tf,
|
||||||
|
require_timm,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
    from PIL import Image
else:

    # Stand-in used when the vision backend is unavailable, so that `Image.open(...)` references in this
    # module still resolve. `open` accepts any arguments and returns None.
    class Image:
        @staticmethod
        def open(*args, **kwargs):
            pass
|
||||||
|
|
||||||
|
|
||||||
|
# Sample image shared by every test, as a local fixture path and as a remote URL.
_COCO_SAMPLE_PATH = "./tests/fixtures/tests_samples/COCO/000000039769.png"
_COCO_SAMPLE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Expected detections of "mishig/tiny-detr-mobilenetsv3" on the sample image with `threshold=0.0`.
_TINY_DETR_EXPECTED = [
    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
]

# Expected detections of "facebook/detr-resnet-50" on the sample image with the default threshold.
_DETR_RESNET_50_EXPECTED = [
    {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
    {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
    {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
    {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
]


def _assert_detections_well_formed(test_case, detections):
    """Assert that `detections` is non-empty and that every entry has the score/label/box schema."""
    test_case.assertGreater(len(detections), 0)
    for detected_object in detections:
        test_case.assertEqual(
            detected_object,
            {
                "score": ANY(float),
                "label": ANY(str),
                "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
            },
        )


def _assert_single_and_batched(test_case, object_detector, expected, **call_kwargs):
    """Run the detector on the sample URL alone, then on a batch of two copies, and compare with `expected`."""
    outputs = object_detector(_COCO_SAMPLE_URL, **call_kwargs)
    test_case.assertEqual(nested_simplify(outputs, decimals=4), expected)

    outputs = object_detector([_COCO_SAMPLE_URL, _COCO_SAMPLE_URL], **call_kwargs)
    test_case.assertEqual(nested_simplify(outputs, decimals=4), [expected, expected])


@require_vision
@require_timm
@require_torch
@is_pipeline_test
class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for `ObjectDetectionPipeline`: output schema, exact detections and the `threshold` argument."""

    model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

    @require_datasets
    def run_pipeline_test(self, model, tokenizer, feature_extractor):
        """Check the output schema for a single local image and for a batch mixing input formats."""
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
        outputs = object_detector(_COCO_SAMPLE_PATH, threshold=0.0)

        _assert_detections_well_formed(self, outputs)

        import datasets

        dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")

        batch = [
            Image.open(_COCO_SAMPLE_PATH),
            _COCO_SAMPLE_URL,
            # RGBA
            dataset[0]["file"],
            # LA
            dataset[1]["file"],
            # L
            dataset[2]["file"],
        ]
        batch_outputs = object_detector(batch, threshold=0.0)

        self.assertEqual(len(batch), len(batch_outputs))
        for outputs in batch_outputs:
            _assert_detections_well_formed(self, outputs)

    @require_tf
    @unittest.skip("Object detection not implemented in TF")
    def test_small_model_tf(self):
        pass

    @require_torch
    def test_small_model_pt(self):
        """Exact detections of the tiny model, for a single image and for a batch."""
        model_id = "mishig/tiny-detr-mobilenetsv3"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        _assert_single_and_batched(self, object_detector, _TINY_DETR_EXPECTED, threshold=0.0)

    @require_torch
    @slow
    def test_large_model_pt(self):
        """Exact detections of DETR ResNet-50 when the pipeline is built from explicit components."""
        model_id = "facebook/detr-resnet-50"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        _assert_single_and_batched(self, object_detector, _DETR_RESNET_50_EXPECTED)

    @require_torch
    @slow
    def test_integration_torch_object_detection(self):
        """Exact detections of DETR ResNet-50 when the pipeline is built through `pipeline(...)`."""
        model_id = "facebook/detr-resnet-50"

        object_detector = pipeline("object-detection", model=model_id)

        _assert_single_and_batched(self, object_detector, _DETR_RESNET_50_EXPECTED)

    @require_torch
    @slow
    def test_threshold(self):
        """Only detections scoring above the requested threshold are returned."""
        threshold = 0.9985
        model_id = "facebook/detr-resnet-50"

        object_detector = pipeline("object-detection", model=model_id)

        outputs = object_detector(_COCO_SAMPLE_URL, threshold=threshold)
        # With this threshold only the two "cat" detections (0.9988 and 0.9987) remain.
        expected = [detection for detection in _DETR_RESNET_50_EXPECTED if detection["score"] > threshold]
        self.assertEqual(nested_simplify(outputs, decimals=4), expected)
|
||||||
Reference in New Issue
Block a user