Faster zero shot image (#21897)

* Make ZeroShotImageClassificationPipeline faster

The pipeline makes separate calls to model for each candidate label.
This commit combines all labels into one call.
Original code takes more that 60 seconds to process one image and 1000
candidate labels. Updated code takes less than 2 seconds.

* implement batching

* code formatting

* Creating an even faster zero-shot-image-classifiction.

Unfortunately super tailored towards CLIP.

Co-Authored-By: Yessen Kanapin <yessen@deepinfra.com>

* Quality.

* Cleanup.

* Order different on the CI it seems.

* Cleanup.

* Quality.

---------

Co-authored-by: Yessen Kanapin <yessen@deepinfra.com>
This commit is contained in:
Nicolas Patry
2023-03-02 19:46:22 +01:00
committed by GitHub
parent 88e5c51a15
commit b2a41d2be4

View File

@@ -1,3 +1,4 @@
from collections import UserDict
from typing import List, Union from typing import List, Union
from ..utils import ( from ..utils import (
@@ -8,7 +9,7 @@ from ..utils import (
logging, logging,
requires_backends, requires_backends,
) )
from .base import PIPELINE_INIT_ARGS, ChunkPipeline from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available(): if is_vision_available():
@@ -17,18 +18,16 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
import torch pass
if is_tf_available(): if is_tf_available():
import tensorflow as tf
from ..tf_utils import stable_softmax from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS) @add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotImageClassificationPipeline(ChunkPipeline): class ZeroShotImageClassificationPipeline(Pipeline):
""" """
Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you
provide an image and a set of `candidate_labels`. provide an image and a set of `candidate_labels`.
@@ -107,42 +106,39 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
return preprocess_params, {}, {} return preprocess_params, {}, {}
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."): def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
n = len(candidate_labels) image = load_image(image)
for i, candidate_label in enumerate(candidate_labels): inputs = self.image_processor(images=[image], return_tensors=self.framework)
image = load_image(image) inputs["candidate_labels"] = candidate_labels
images = self.image_processor(images=[image], return_tensors=self.framework) sequences = [hypothesis_template.format(x) for x in candidate_labels]
sequence = hypothesis_template.format(candidate_label) text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
inputs = self.tokenizer(sequence, return_tensors=self.framework) inputs["text_inputs"] = [text_inputs]
inputs["pixel_values"] = images.pixel_values return inputs
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
def _forward(self, model_inputs): def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last") candidate_labels = model_inputs.pop("candidate_labels")
candidate_label = model_inputs.pop("candidate_label") text_inputs = model_inputs.pop("text_inputs")
outputs = self.model(**model_inputs) if isinstance(text_inputs[0], UserDict):
text_inputs = text_inputs[0]
else:
# Batching case.
text_inputs = text_inputs[0][0]
# Clip does crossproduct scoring by default, so we're only outputs = self.model(**text_inputs, **model_inputs)
# interested in the results where image and text and in the same
# batch position.
diag = torch.diagonal if self.framework == "pt" else tf.linalg.diag_part
logits_per_image = diag(outputs.logits_per_image)
model_outputs = { model_outputs = {
"is_last": is_last, "candidate_labels": candidate_labels,
"candidate_label": candidate_label, "logits": outputs.logits_per_image,
"logits_per_image": logits_per_image,
} }
return model_outputs return model_outputs
def postprocess(self, model_outputs): def postprocess(self, model_outputs):
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] candidate_labels = model_outputs.pop("candidate_labels")
logits = model_outputs["logits"][0]
if self.framework == "pt": if self.framework == "pt":
logits = torch.cat([output["logits_per_image"] for output in model_outputs]) probs = logits.softmax(dim=-1).squeeze(-1)
probs = logits.softmax(dim=0)
scores = probs.tolist() scores = probs.tolist()
else: else:
logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0) probs = stable_softmax(logits, axis=-1)
probs = stable_softmax(logits, axis=0)
scores = probs.numpy().tolist() scores = probs.numpy().tolist()
result = [ result = [