Adding ZeroShotImageClassificationPipeline (#12119)
* [Proposal] Adding ZeroShotImageClassificationPipeline - Based on CLIP * WIP, Resurrection in progress. * Resurrection... achieved. * Reword handling different `padding_value` for `feature_extractor` and `tokenizer`. * Thanks doc-builder ! * Adding docs + global namespace `ZeroShotImageClassificationPipeline`. * Fixing templates. * Make the test pass and be robust to floating error. * Addressing suraj's comments on docs mostly. * Tf support start. * TF support. * Update src/transformers/pipelines/zero_shot_image_classification.py Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -428,6 +428,12 @@ See [`TokenClassificationPipeline`] for all details.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### ZeroShotImageClassificationPipeline
|
||||||
|
|
||||||
|
[[autodoc]] ZeroShotImageClassificationPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
## Parent class: `Pipeline`
|
## Parent class: `Pipeline`
|
||||||
|
|
||||||
[[autodoc]] Pipeline
|
[[autodoc]] Pipeline
|
||||||
|
|||||||
@@ -365,6 +365,7 @@ _import_structure = {
|
|||||||
"TokenClassificationPipeline",
|
"TokenClassificationPipeline",
|
||||||
"TranslationPipeline",
|
"TranslationPipeline",
|
||||||
"ZeroShotClassificationPipeline",
|
"ZeroShotClassificationPipeline",
|
||||||
|
"ZeroShotImageClassificationPipeline",
|
||||||
"pipeline",
|
"pipeline",
|
||||||
],
|
],
|
||||||
"processing_utils": ["ProcessorMixin"],
|
"processing_utils": ["ProcessorMixin"],
|
||||||
@@ -2597,6 +2598,7 @@ if TYPE_CHECKING:
|
|||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
TranslationPipeline,
|
TranslationPipeline,
|
||||||
ZeroShotClassificationPipeline,
|
ZeroShotClassificationPipeline,
|
||||||
|
ZeroShotImageClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from .processing_utils import ProcessorMixin
|
from .processing_utils import ProcessorMixin
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from .token_classification import (
|
|||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
)
|
)
|
||||||
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
||||||
|
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -239,6 +240,13 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"type": "text",
|
"type": "text",
|
||||||
},
|
},
|
||||||
|
"zero-shot-image-classification": {
|
||||||
|
"impl": ZeroShotImageClassificationPipeline,
|
||||||
|
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||||
|
"pt": (AutoModel,) if is_torch_available() else (),
|
||||||
|
"default": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"},
|
||||||
|
"type": "multimodal",
|
||||||
|
},
|
||||||
"conversational": {
|
"conversational": {
|
||||||
"impl": ConversationalPipeline,
|
"impl": ConversationalPipeline,
|
||||||
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
||||||
|
|||||||
@@ -105,7 +105,10 @@ def _pad(items, key, padding_value, padding_side):
|
|||||||
|
|
||||||
|
|
||||||
def pad_collate_fn(tokenizer, feature_extractor):
|
def pad_collate_fn(tokenizer, feature_extractor):
|
||||||
padding_side = "right"
|
# Tokenizer
|
||||||
|
t_padding_side = None
|
||||||
|
# Feature extractor
|
||||||
|
f_padding_side = None
|
||||||
if tokenizer is None and feature_extractor is None:
|
if tokenizer is None and feature_extractor is None:
|
||||||
raise ValueError("Pipeline without tokenizer or feature_extractor cannot do batching")
|
raise ValueError("Pipeline without tokenizer or feature_extractor cannot do batching")
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
@@ -115,12 +118,22 @@ def pad_collate_fn(tokenizer, feature_extractor):
|
|||||||
"`pipe.tokenizer.pad_token_id = model.config.eos_token_id`."
|
"`pipe.tokenizer.pad_token_id = model.config.eos_token_id`."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
padding_value = tokenizer.pad_token_id
|
t_padding_value = tokenizer.pad_token_id
|
||||||
padding_side = tokenizer.padding_side
|
t_padding_side = tokenizer.padding_side
|
||||||
if feature_extractor is not None:
|
if feature_extractor is not None:
|
||||||
# Feature extractor can be images, where no padding is expected
|
# Feature extractor can be images, where no padding is expected
|
||||||
padding_value = getattr(feature_extractor, "padding_value", None)
|
f_padding_value = getattr(feature_extractor, "padding_value", None)
|
||||||
padding_side = getattr(feature_extractor, "padding_side", None)
|
f_padding_side = getattr(feature_extractor, "padding_side", None)
|
||||||
|
|
||||||
|
if t_padding_side is not None and f_padding_side is not None and t_padding_side != f_padding_side:
|
||||||
|
raise ValueError(
|
||||||
|
f"The feature extractor, and tokenizer don't agree on padding side {t_padding_side} != {f_padding_side}"
|
||||||
|
)
|
||||||
|
padding_side = "right"
|
||||||
|
if t_padding_side is not None:
|
||||||
|
padding_side = t_padding_side
|
||||||
|
if f_padding_side is not None:
|
||||||
|
padding_side = f_padding_side
|
||||||
|
|
||||||
def inner(items):
|
def inner(items):
|
||||||
keys = set(items[0].keys())
|
keys = set(items[0].keys())
|
||||||
@@ -132,11 +145,16 @@ def pad_collate_fn(tokenizer, feature_extractor):
|
|||||||
# input_values, input_pixels, input_ids, ...
|
# input_values, input_pixels, input_ids, ...
|
||||||
padded = {}
|
padded = {}
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("input_"):
|
if key in {"input_ids"}:
|
||||||
_padding_value = padding_value
|
_padding_value = t_padding_value
|
||||||
elif key == "p_mask":
|
elif key in {"input_values", "pixel_values", "input_features"}:
|
||||||
|
_padding_value = f_padding_value
|
||||||
|
elif key in {"p_mask"}:
|
||||||
_padding_value = 1
|
_padding_value = 1
|
||||||
|
elif key in {"attention_mask", "token_type_ids"}:
|
||||||
|
_padding_value = 0
|
||||||
else:
|
else:
|
||||||
|
# This is likely another random key maybe even user provided
|
||||||
_padding_value = 0
|
_padding_value = 0
|
||||||
padded[key] = _pad(items, key, _padding_value, padding_side)
|
padded[key] = _pad(items, key, _padding_value, padding_side)
|
||||||
return padded
|
return padded
|
||||||
|
|||||||
129
src/transformers/pipelines/zero_shot_image_classification.py
Normal file
129
src/transformers/pipelines/zero_shot_image_classification.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from ..file_utils import (
|
||||||
|
add_end_docstrings,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
from ..utils import logging
|
||||||
|
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||||
|
class ZeroShotImageClassificationPipeline(ChunkPipeline):
|
||||||
|
"""
|
||||||
|
Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you
|
||||||
|
provide an image and a set of `candidate_labels`.
|
||||||
|
|
||||||
|
This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||||
|
`"zero-shot-image-classification"`.
|
||||||
|
|
||||||
|
See the list of available models on
|
||||||
|
[huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
requires_backends(self, "vision")
|
||||||
|
# No specific FOR_XXX available yet
|
||||||
|
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||||
|
|
||||||
|
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
||||||
|
"""
|
||||||
|
Assign labels to the image(s) passed as inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||||
|
The pipeline handles three types of images:
|
||||||
|
|
||||||
|
- A string containing a http link pointing to an image
|
||||||
|
- A string containing a local path to an image
|
||||||
|
- An image loaded in PIL directly
|
||||||
|
|
||||||
|
candidate_labels (`List[str]`):
|
||||||
|
The candidate labels for this image
|
||||||
|
|
||||||
|
hypothesis_template (`str`, *optional*, defaults to `"This is a photo of {}."`):
|
||||||
|
The sentence used in conjunction with *candidate_labels* to attempt the image classification by
|
||||||
|
replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
|
||||||
|
logits_per_image
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A list of dictionaries containing the results, one dictionary per proposed label. The dictionaries contain the
|
||||||
|
following keys:
|
||||||
|
|
||||||
|
- **label** (`str`) -- The label identified by the model. It is one of the suggested `candidate_label`.
|
||||||
|
- **score** (`float`) -- The score attributed by the model for that label (between 0 and 1).
|
||||||
|
"""
|
||||||
|
return super().__call__(images, **kwargs)
|
||||||
|
|
||||||
|
def _sanitize_parameters(self, **kwargs):
|
||||||
|
preprocess_params = {}
|
||||||
|
if "candidate_labels" in kwargs:
|
||||||
|
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
|
||||||
|
if "hypothesis_template" in kwargs:
|
||||||
|
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
|
||||||
|
|
||||||
|
return preprocess_params, {}, {}
|
||||||
|
|
||||||
|
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
|
||||||
|
n = len(candidate_labels)
|
||||||
|
for i, candidate_label in enumerate(candidate_labels):
|
||||||
|
image = load_image(image)
|
||||||
|
images = self.feature_extractor(images=[image], return_tensors=self.framework)
|
||||||
|
sequence = hypothesis_template.format(candidate_label)
|
||||||
|
inputs = self.tokenizer(sequence, return_tensors=self.framework)
|
||||||
|
inputs["pixel_values"] = images.pixel_values
|
||||||
|
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
|
||||||
|
|
||||||
|
def _forward(self, model_inputs):
|
||||||
|
is_last = model_inputs.pop("is_last")
|
||||||
|
candidate_label = model_inputs.pop("candidate_label")
|
||||||
|
outputs = self.model(**model_inputs)
|
||||||
|
|
||||||
|
# Clip does crossproduct scoring by default, so we're only
|
||||||
|
# interested in the results where image and text are in the same
|
||||||
|
# batch position.
|
||||||
|
diag = torch.diagonal if self.framework == "pt" else tf.linalg.diag_part
|
||||||
|
logits_per_image = diag(outputs.logits_per_image)
|
||||||
|
|
||||||
|
model_outputs = {
|
||||||
|
"is_last": is_last,
|
||||||
|
"candidate_label": candidate_label,
|
||||||
|
"logits_per_image": logits_per_image,
|
||||||
|
}
|
||||||
|
return model_outputs
|
||||||
|
|
||||||
|
def postprocess(self, model_outputs):
|
||||||
|
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
|
||||||
|
if self.framework == "pt":
|
||||||
|
logits = torch.cat([output["logits_per_image"] for output in model_outputs])
|
||||||
|
probs = logits.softmax(dim=0)
|
||||||
|
scores = probs.tolist()
|
||||||
|
else:
|
||||||
|
logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0)
|
||||||
|
probs = tf.nn.softmax(logits, axis=0)
|
||||||
|
scores = probs.numpy().tolist()
|
||||||
|
|
||||||
|
result = [
|
||||||
|
{"score": score, "label": candidate_label}
|
||||||
|
for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
|
||||||
|
]
|
||||||
|
return result
|
||||||
238
tests/test_pipelines_zero_shot_image_classification.py
Normal file
238
tests/test_pipelines_zero_shot_image_classification.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_vision_available
|
||||||
|
from transformers.pipelines import pipeline
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
nested_simplify,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
@staticmethod
|
||||||
|
def open(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@is_pipeline_test
|
||||||
|
class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
|
# Deactivating auto tests since we don't have a good MODEL_FOR_XX mapping,
|
||||||
|
# and only CLIP would be there for now.
|
||||||
|
# model_mapping = {CLIPConfig: CLIPModel}
|
||||||
|
|
||||||
|
# def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
|
# if tokenizer is None:
|
||||||
|
# # Side effect of no Fast Tokenizer class for these model, so skipping
|
||||||
|
# # But the slow tokenizer test should still run as they're quite small
|
||||||
|
# self.skipTest("No tokenizer available")
|
||||||
|
# return
|
||||||
|
# # return None, None
|
||||||
|
|
||||||
|
# image_classifier = ZeroShotImageClassificationPipeline(
|
||||||
|
# model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # test with a raw waveform
|
||||||
|
# image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
# image2 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
# return image_classifier, [image, image2]
|
||||||
|
|
||||||
|
# def run_pipeline_test(self, pipe, examples):
|
||||||
|
# image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
# outputs = pipe(image, candidate_labels=["A", "B"])
|
||||||
|
# self.assertEqual(outputs, {"text": ANY(str)})
|
||||||
|
|
||||||
|
# # Batching
|
||||||
|
# outputs = pipe([image] * 3, batch_size=2, candidate_labels=["A", "B"])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
model="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
|
||||||
|
)
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
output = image_classifier(image, candidate_labels=["a", "b", "c"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = image_classifier([image] * 5, candidate_labels=["A", "B", "C"], batch_size=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
# Pipeline outputs are supposed to be deterministic.
|
||||||
|
# So we could in theory have real values "A", "B", "C" instead
|
||||||
|
# of ANY(str).
|
||||||
|
# However it seems that in this particular case, the floating
|
||||||
|
# scores are so close, we enter floating error approximation
|
||||||
|
# and the order is not guaranteed anymore with batching.
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
model="hf-internal-testing/tiny-random-clip-zero-shot-image-classification", framework="tf"
|
||||||
|
)
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
output = image_classifier(image, candidate_labels=["a", "b", "c"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = image_classifier([image] * 5, candidate_labels=["A", "B", "C"], batch_size=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
# Pipeline outputs are supposed to be deterministic.
|
||||||
|
# So we could in theory have real values "A", "B", "C" instead
|
||||||
|
# of ANY(str).
|
||||||
|
# However it seems that in this particular case, the floating
|
||||||
|
# scores are so close, we enter floating error approximation
|
||||||
|
# and the order is not guaranteed anymore with batching.
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
{"score": 0.333, "label": ANY(str)},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_large_model_pt(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
task="zero-shot-image-classification",
|
||||||
|
model="openai/clip-vit-base-patch32",
|
||||||
|
)
|
||||||
|
# This is an image of 2 cats with remotes and no planes
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
output = image_classifier(image, candidate_labels=["cat", "plane", "remote"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[
|
||||||
|
{"score": 0.941, "label": "cat"},
|
||||||
|
{"score": 0.055, "label": "remote"},
|
||||||
|
{"score": 0.003, "label": "plane"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = image_classifier([image] * 5, candidate_labels=["cat", "plane", "remote"], batch_size=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.941, "label": "cat"},
|
||||||
|
{"score": 0.055, "label": "remote"},
|
||||||
|
{"score": 0.003, "label": "plane"},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
* 5,
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_large_model_tf(self):
|
||||||
|
image_classifier = pipeline(
|
||||||
|
task="zero-shot-image-classification", model="openai/clip-vit-base-patch32", framework="tf"
|
||||||
|
)
|
||||||
|
# This is an image of 2 cats with remotes and no planes
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
output = image_classifier(image, candidate_labels=["cat", "plane", "remote"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[
|
||||||
|
{"score": 0.941, "label": "cat"},
|
||||||
|
{"score": 0.055, "label": "remote"},
|
||||||
|
{"score": 0.003, "label": "plane"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = image_classifier([image] * 5, candidate_labels=["cat", "plane", "remote"], batch_size=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.941, "label": "cat"},
|
||||||
|
{"score": 0.055, "label": "remote"},
|
||||||
|
{"score": 0.003, "label": "plane"},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
* 5,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user