* Add ZeroShotObjectDetectionPipeline (#18445) * Add AutoModelForZeroShotObjectDetection task This commit also adds the following - Add explicit _processor method for ZeroShotObjectDetectionPipeline. This is necessary as pipelines don't auto infer processors yet and `OwlVitProcessor` wraps tokenizer and feature_extractor together, to process multiple images at once - Add auto tests and other tests for ZeroShotObjectDetectionPipeline * Add AutoModelForZeroShotObjectDetection task This commit also adds the following - Add explicit _processor method for ZeroShotObjectDetectionPipeline. This is necessary as pipelines don't auto infer processors yet and `OwlVitProcessor` wraps tokenizer and feature_extractor together, to process multiple images at once - Add auto tests and other tests for ZeroShotObjectDetectionPipeline * Add batching for ZeroShotObjectDetectionPipeline * Fix doc-string ZeroShotObjectDetectionPipeline * Fix output format: ZeroShotObjectDetectionPipeline
This commit is contained in:
@@ -43,6 +43,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||||||
- [`VisualQuestionAnsweringPipeline`]
|
- [`VisualQuestionAnsweringPipeline`]
|
||||||
- [`ZeroShotClassificationPipeline`]
|
- [`ZeroShotClassificationPipeline`]
|
||||||
- [`ZeroShotImageClassificationPipeline`]
|
- [`ZeroShotImageClassificationPipeline`]
|
||||||
|
- [`ZeroShotObjectDetectionPipeline`]
|
||||||
|
|
||||||
## The pipeline abstraction
|
## The pipeline abstraction
|
||||||
|
|
||||||
@@ -456,6 +457,12 @@ See [`TokenClassificationPipeline`] for all details.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### ZeroShotObjectDetectionPipeline
|
||||||
|
|
||||||
|
[[autodoc]] ZeroShotObjectDetectionPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
## Parent class: `Pipeline`
|
## Parent class: `Pipeline`
|
||||||
|
|
||||||
[[autodoc]] Pipeline
|
[[autodoc]] Pipeline
|
||||||
|
|||||||
@@ -174,6 +174,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
|||||||
|
|
||||||
[[autodoc]] AutoModelForInstanceSegmentation
|
[[autodoc]] AutoModelForInstanceSegmentation
|
||||||
|
|
||||||
|
## AutoModelForZeroShotObjectDetection
|
||||||
|
|
||||||
|
[[autodoc]] AutoModelForZeroShotObjectDetection
|
||||||
|
|
||||||
## TFAutoModel
|
## TFAutoModel
|
||||||
|
|
||||||
[[autodoc]] TFAutoModel
|
[[autodoc]] TFAutoModel
|
||||||
|
|||||||
@@ -442,6 +442,7 @@ _import_structure = {
|
|||||||
"VisualQuestionAnsweringPipeline",
|
"VisualQuestionAnsweringPipeline",
|
||||||
"ZeroShotClassificationPipeline",
|
"ZeroShotClassificationPipeline",
|
||||||
"ZeroShotImageClassificationPipeline",
|
"ZeroShotImageClassificationPipeline",
|
||||||
|
"ZeroShotObjectDetectionPipeline",
|
||||||
"pipeline",
|
"pipeline",
|
||||||
],
|
],
|
||||||
"processing_utils": ["ProcessorMixin"],
|
"processing_utils": ["ProcessorMixin"],
|
||||||
@@ -878,6 +879,7 @@ else:
|
|||||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
|
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
"AutoModelForAudioClassification",
|
"AutoModelForAudioClassification",
|
||||||
"AutoModelForAudioFrameClassification",
|
"AutoModelForAudioFrameClassification",
|
||||||
@@ -905,6 +907,7 @@ else:
|
|||||||
"AutoModelForVision2Seq",
|
"AutoModelForVision2Seq",
|
||||||
"AutoModelForVisualQuestionAnswering",
|
"AutoModelForVisualQuestionAnswering",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
|
"AutoModelForZeroShotObjectDetection",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.bart"].extend(
|
_import_structure["models.bart"].extend(
|
||||||
@@ -3407,6 +3410,7 @@ if TYPE_CHECKING:
|
|||||||
VisualQuestionAnsweringPipeline,
|
VisualQuestionAnsweringPipeline,
|
||||||
ZeroShotClassificationPipeline,
|
ZeroShotClassificationPipeline,
|
||||||
ZeroShotImageClassificationPipeline,
|
ZeroShotImageClassificationPipeline,
|
||||||
|
ZeroShotObjectDetectionPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from .processing_utils import ProcessorMixin
|
from .processing_utils import ProcessorMixin
|
||||||
@@ -3772,6 +3776,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
@@ -3800,6 +3805,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForVideoClassification,
|
AutoModelForVideoClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotObjectDetection,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
from .models.bart import (
|
from .models.bart import (
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ else:
|
|||||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
|
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
"AutoModelForAudioClassification",
|
"AutoModelForAudioClassification",
|
||||||
"AutoModelForAudioFrameClassification",
|
"AutoModelForAudioFrameClassification",
|
||||||
@@ -96,6 +97,7 @@ else:
|
|||||||
"AutoModelForVisualQuestionAnswering",
|
"AutoModelForVisualQuestionAnswering",
|
||||||
"AutoModelForDocumentQuestionAnswering",
|
"AutoModelForDocumentQuestionAnswering",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
|
"AutoModelForZeroShotObjectDetection",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -215,6 +217,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
@@ -243,6 +246,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForVideoClassification,
|
AutoModelForVideoClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotObjectDetection,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -472,6 +472,13 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Zero Shot Object Detection mapping
|
||||||
|
("owlvit", "OwlViTForObjectDetection")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Seq2Seq Causal LM mapping
|
# Model for Seq2Seq Causal LM mapping
|
||||||
@@ -830,6 +837,9 @@ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
||||||
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
|
||||||
|
)
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
@@ -1016,6 +1026,15 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
|
|||||||
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
|
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForZeroShotObjectDetection = auto_class_update(
|
||||||
|
AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForVideoClassification(_BaseAutoModelClass):
|
class AutoModelForVideoClassification(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from .token_classification import (
|
|||||||
from .visual_question_answering import VisualQuestionAnsweringPipeline
|
from .visual_question_answering import VisualQuestionAnsweringPipeline
|
||||||
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
||||||
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
||||||
|
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -124,6 +125,7 @@ if is_torch_available():
|
|||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotObjectDetection,
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..modeling_tf_utils import TFPreTrainedModel
|
from ..modeling_tf_utils import TFPreTrainedModel
|
||||||
@@ -335,6 +337,13 @@ SUPPORTED_TASKS = {
|
|||||||
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
|
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
|
||||||
"type": "image",
|
"type": "image",
|
||||||
},
|
},
|
||||||
|
"zero-shot-object-detection": {
|
||||||
|
"impl": ZeroShotObjectDetectionPipeline,
|
||||||
|
"tf": (),
|
||||||
|
"pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
|
||||||
|
"default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
|
||||||
|
"type": "multimodal",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||||
|
|||||||
278
src/transformers/pipelines/zero_shot_object_detection.py
Normal file
278
src/transformers/pipelines/zero_shot_object_detection.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..tokenization_utils_base import BatchEncoding
|
||||||
|
from ..utils import (
|
||||||
|
add_end_docstrings,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||||
|
class ZeroShotObjectDetectionPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
|
||||||
|
objects when you provide an image and a set of `candidate_labels`.
|
||||||
|
|
||||||
|
This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||||
|
`"zero-shot-object-detection"`.
|
||||||
|
|
||||||
|
See the list of available models on
|
||||||
|
[huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if self.framework == "tf":
|
||||||
|
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
||||||
|
|
||||||
|
requires_backends(self, "vision")
|
||||||
|
self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: Union[str, List[str], "Image.Image", List["Image.Image"]],
|
||||||
|
text_queries: Union[str, List[str], List[List[str]]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||||
|
The pipeline handles three types of images:
|
||||||
|
|
||||||
|
- A string containing an http url pointing to an image
|
||||||
|
- A string containing a local path to an image
|
||||||
|
- An image loaded in PIL directly
|
||||||
|
|
||||||
|
text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with.
|
||||||
|
If given multiple images, `text_queries` should be provided as a list of lists, where each nested list
|
||||||
|
contains the text queries for the corresponding image.
|
||||||
|
|
||||||
|
threshold (`float`, *optional*, defaults to 0.1):
|
||||||
|
The probability necessary to make a prediction.
|
||||||
|
|
||||||
|
top_k (`int`, *optional*, defaults to None):
|
||||||
|
The number of top predictions that will be returned by the pipeline. If the provided number is `None`
|
||||||
|
or higher than the number of predictions available, it will default to the number of predictions.
|
||||||
|
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A list of lists containing prediction results, one list per input image. Each list contains dictionaries
|
||||||
|
with the following keys:
|
||||||
|
|
||||||
|
- **label** (`str`) -- Text query corresponding to the found object.
|
||||||
|
- **score** (`float`) -- Score corresponding to the object (between 0 and 1).
|
||||||
|
- **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
|
||||||
|
dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
|
||||||
|
"""
|
||||||
|
if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)):
|
||||||
|
if isinstance(images, (str, Image.Image)):
|
||||||
|
inputs = {"images": images, "text_queries": text_queries}
|
||||||
|
elif isinstance(images, List):
|
||||||
|
assert len(images) == 1, "Input text_queries and images must have correspondance"
|
||||||
|
inputs = {"images": images[0], "text_queries": text_queries}
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Innapropriate type of images: {type(images)}")
|
||||||
|
|
||||||
|
elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)):
|
||||||
|
if isinstance(images, (Image.Image, str)):
|
||||||
|
images = [images]
|
||||||
|
assert len(images) == len(text_queries), "Input text_queries and images must have correspondance"
|
||||||
|
inputs = {"images": images, "text_queries": text_queries}
|
||||||
|
else:
|
||||||
|
"""
|
||||||
|
Supports the following format
|
||||||
|
- {"images": images, "text_queries": text_queries}
|
||||||
|
"""
|
||||||
|
inputs = images
|
||||||
|
results = super().__call__(inputs, **kwargs)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _sanitize_parameters(self, **kwargs):
|
||||||
|
postprocess_params = {}
|
||||||
|
if "threshold" in kwargs:
|
||||||
|
postprocess_params["threshold"] = kwargs["threshold"]
|
||||||
|
if "top_k" in kwargs:
|
||||||
|
postprocess_params["top_k"] = kwargs["top_k"]
|
||||||
|
return {}, {}, postprocess_params
|
||||||
|
|
||||||
|
def preprocess(self, inputs):
|
||||||
|
if not isinstance(inputs["images"], List):
|
||||||
|
inputs["images"] = [inputs["images"]]
|
||||||
|
images = [load_image(img) for img in inputs["images"]]
|
||||||
|
text_queries = inputs["text_queries"]
|
||||||
|
if isinstance(text_queries, str) or isinstance(text_queries[0], str):
|
||||||
|
text_queries = [text_queries]
|
||||||
|
|
||||||
|
target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images]
|
||||||
|
target_sizes = torch.cat(target_sizes)
|
||||||
|
inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt")
|
||||||
|
return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs}
|
||||||
|
|
||||||
|
def _forward(self, model_inputs):
|
||||||
|
target_sizes = model_inputs.pop("target_sizes")
|
||||||
|
text_queries = model_inputs.pop("text_queries")
|
||||||
|
outputs = self.model(**model_inputs)
|
||||||
|
|
||||||
|
model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs})
|
||||||
|
return model_outputs
|
||||||
|
|
||||||
|
def postprocess(self, model_outputs, threshold=0.1, top_k=None):
|
||||||
|
texts = model_outputs["text_queries"]
|
||||||
|
|
||||||
|
outputs = self.feature_extractor.post_process(
|
||||||
|
outputs=model_outputs, target_sizes=model_outputs["target_sizes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(len(outputs)):
|
||||||
|
keep = outputs[i]["scores"] >= threshold
|
||||||
|
labels = outputs[i]["labels"][keep].tolist()
|
||||||
|
scores = outputs[i]["scores"][keep].tolist()
|
||||||
|
boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]]
|
||||||
|
|
||||||
|
result = [
|
||||||
|
{"score": score, "label": texts[i][label], "box": box}
|
||||||
|
for score, label, box in zip(scores, labels, boxes)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = sorted(result, key=lambda x: x["score"], reverse=True)
|
||||||
|
if top_k:
|
||||||
|
result = result[:top_k]
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
|
||||||
|
"""
|
||||||
|
Turns list [xmin, xmax, ymin, ymax] into dict { "xmin": xmin, ... }
|
||||||
|
|
||||||
|
Args:
|
||||||
|
box (`torch.Tensor`): Tensor containing the coordinates in corners format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
|
||||||
|
"""
|
||||||
|
if self.framework != "pt":
|
||||||
|
raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
|
||||||
|
xmin, ymin, xmax, ymax = box.int().tolist()
|
||||||
|
bbox = {
|
||||||
|
"xmin": xmin,
|
||||||
|
"ymin": ymin,
|
||||||
|
"xmax": xmax,
|
||||||
|
"ymax": ymax,
|
||||||
|
}
|
||||||
|
return bbox
|
||||||
|
|
||||||
|
# Replication of OwlViTProcessor __call__ method, since pipelines don't auto infer processor's yet!
|
||||||
|
def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
|
||||||
|
"""
|
||||||
|
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
|
||||||
|
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
|
||||||
|
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||||
|
CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
|
||||||
|
doctsring of the above two methods for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (`str`, `List[str]`, `List[List[str]]`):
|
||||||
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
||||||
|
`List[torch.Tensor]`):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||||
|
number of channels, H and W are image height and width.
|
||||||
|
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||||
|
If set, will return tensors of a particular framework. Acceptable values are:
|
||||||
|
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||||
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
|
Returns:
|
||||||
|
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
||||||
|
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||||
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||||
|
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||||
|
`None`).
|
||||||
|
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if text is None and images is None:
|
||||||
|
raise ValueError("You have to specify at least one text or image. Both cannot be none.")
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
|
||||||
|
encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
|
||||||
|
|
||||||
|
elif isinstance(text, List) and isinstance(text[0], List):
|
||||||
|
encodings = []
|
||||||
|
|
||||||
|
# Maximum number of queries across batch
|
||||||
|
max_num_queries = max([len(t) for t in text])
|
||||||
|
|
||||||
|
# Pad all batch samples to max number of text queries
|
||||||
|
for t in text:
|
||||||
|
if len(t) != max_num_queries:
|
||||||
|
t = t + [" "] * (max_num_queries - len(t))
|
||||||
|
|
||||||
|
encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
|
||||||
|
encodings.append(encoding)
|
||||||
|
else:
|
||||||
|
raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
|
||||||
|
|
||||||
|
if return_tensors == "np":
|
||||||
|
input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
|
||||||
|
attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
|
||||||
|
|
||||||
|
elif return_tensors == "pt" and is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
|
||||||
|
attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
|
||||||
|
|
||||||
|
elif return_tensors == "tf" and is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
|
||||||
|
attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Target return tensor type could not be returned")
|
||||||
|
|
||||||
|
encoding = BatchEncoding()
|
||||||
|
encoding["input_ids"] = input_ids
|
||||||
|
encoding["attention_mask"] = attention_mask
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
|
||||||
|
|
||||||
|
if text is not None and images is not None:
|
||||||
|
encoding["pixel_values"] = image_features.pixel_values
|
||||||
|
return encoding
|
||||||
|
elif text is not None:
|
||||||
|
return encoding
|
||||||
|
else:
|
||||||
|
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
|
||||||
@@ -418,6 +418,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
|||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_MAPPING = None
|
MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -606,6 +609,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelWithLMHead(metaclass=DummyObject):
|
class AutoModelWithLMHead(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
263
tests/pipelines/test_pipelines_zero_shot_object_detection.py
Normal file
263
tests/pipelines/test_pipelines_zero_shot_object_detection.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
nested_simplify,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
@staticmethod
|
||||||
|
def open(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
@is_pipeline_test
|
||||||
|
class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
|
|
||||||
|
model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
|
object_detector = pipeline(
|
||||||
|
"zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
|
||||||
|
)
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
{
|
||||||
|
"images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
"text_queries": ["cat", "remote", "couch"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return object_detector, examples
|
||||||
|
|
||||||
|
def run_pipeline_test(self, object_detector, examples):
|
||||||
|
batch_outputs = object_detector(examples, threshold=0.0)
|
||||||
|
|
||||||
|
self.assertEqual(len(examples), len(batch_outputs))
|
||||||
|
for outputs in batch_outputs:
|
||||||
|
for output_per_image in outputs:
|
||||||
|
self.assertGreater(len(output_per_image), 0)
|
||||||
|
for detected_object in output_per_image:
|
||||||
|
self.assertEqual(
|
||||||
|
detected_object,
|
||||||
|
{
|
||||||
|
"score": ANY(float),
|
||||||
|
"label": ANY(str),
|
||||||
|
"box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
@unittest.skip("Zero Shot Object Detection not implemented in TF")
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt(self):
|
||||||
|
object_detector = pipeline(
|
||||||
|
"zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
text_queries=["cat", "remote", "couch"],
|
||||||
|
threshold=0.64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
|
||||||
|
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
|
||||||
|
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
|
||||||
|
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
["./tests/fixtures/tests_samples/COCO/000000039769.png"],
|
||||||
|
text_queries=["cat", "remote", "couch"],
|
||||||
|
threshold=0.64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
|
||||||
|
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
|
||||||
|
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
|
||||||
|
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
text_queries=[["cat", "remote", "couch"]],
|
||||||
|
threshold=0.64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
|
||||||
|
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
|
||||||
|
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
|
||||||
|
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
[
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
],
|
||||||
|
text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
|
||||||
|
threshold=0.64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
|
||||||
|
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
|
||||||
|
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
|
||||||
|
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
|
||||||
|
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
|
||||||
|
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
|
||||||
|
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_large_model_pt(self):
|
||||||
|
object_detector = pipeline("zero-shot-object-detection")
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg", text_queries=["cat", "remote", "couch"]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
|
||||||
|
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
|
||||||
|
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
|
||||||
|
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
|
||||||
|
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
[
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
],
|
||||||
|
text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
|
||||||
|
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
|
||||||
|
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
|
||||||
|
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
|
||||||
|
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
|
||||||
|
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
|
||||||
|
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
|
||||||
|
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
|
||||||
|
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
@unittest.skip("Zero Shot Object Detection not implemented in TF")
|
||||||
|
def test_large_model_tf(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_threshold(self):
|
||||||
|
threshold = 0.2
|
||||||
|
object_detector = pipeline("zero-shot-object-detection")
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
text_queries=["cat", "remote", "couch"],
|
||||||
|
threshold=threshold,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
|
||||||
|
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
|
||||||
|
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_top_k(self):
|
||||||
|
top_k = 2
|
||||||
|
object_detector = pipeline("zero-shot-object-detection")
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
text_queries=["cat", "remote", "couch"],
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
|
||||||
|
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
@@ -58,6 +58,11 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||||||
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
|
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
|
||||||
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
|
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
|
||||||
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
|
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
|
||||||
|
(
|
||||||
|
"zero-shot-object-detection",
|
||||||
|
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
|
||||||
|
"AutoModelForZeroShotObjectDetection",
|
||||||
|
),
|
||||||
("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
|
("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
|
||||||
("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
|
("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
|
||||||
("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
|
("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
|
||||||
|
|||||||
Reference in New Issue
Block a user