Faster zero shot image (#21897)
* Make ZeroShotImageClassificationPipeline faster The pipeline makes separate calls to model for each candidate label. This commit combines all labels into one call. Original code takes more that 60 seconds to process one image and 1000 candidate labels. Updated code takes less than 2 seconds. * implement batching * code formatting * Creating an even faster zero-shot-image-classifiction. Unfortunately super tailored towards CLIP. Co-Authored-By: Yessen Kanapin <yessen@deepinfra.com> * Quality. * Cleanup. * Order different on the CI it seems. * Cleanup. * Quality. --------- Co-authored-by: Yessen Kanapin <yessen@deepinfra.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from collections import UserDict
|
||||
from typing import List, Union
|
||||
|
||||
from ..utils import (
|
||||
@@ -8,7 +9,7 @@ from ..utils import (
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -17,18 +18,16 @@ if is_vision_available():
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
pass
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from ..tf_utils import stable_softmax
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class ZeroShotImageClassificationPipeline(ChunkPipeline):
|
||||
class ZeroShotImageClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you
|
||||
provide an image and a set of `candidate_labels`.
|
||||
@@ -107,42 +106,39 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
|
||||
return preprocess_params, {}, {}
|
||||
|
||||
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
|
||||
n = len(candidate_labels)
|
||||
for i, candidate_label in enumerate(candidate_labels):
|
||||
image = load_image(image)
|
||||
images = self.image_processor(images=[image], return_tensors=self.framework)
|
||||
sequence = hypothesis_template.format(candidate_label)
|
||||
inputs = self.tokenizer(sequence, return_tensors=self.framework)
|
||||
inputs["pixel_values"] = images.pixel_values
|
||||
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
|
||||
inputs = self.image_processor(images=[image], return_tensors=self.framework)
|
||||
inputs["candidate_labels"] = candidate_labels
|
||||
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
|
||||
inputs["text_inputs"] = [text_inputs]
|
||||
return inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
is_last = model_inputs.pop("is_last")
|
||||
candidate_label = model_inputs.pop("candidate_label")
|
||||
outputs = self.model(**model_inputs)
|
||||
candidate_labels = model_inputs.pop("candidate_labels")
|
||||
text_inputs = model_inputs.pop("text_inputs")
|
||||
if isinstance(text_inputs[0], UserDict):
|
||||
text_inputs = text_inputs[0]
|
||||
else:
|
||||
# Batching case.
|
||||
text_inputs = text_inputs[0][0]
|
||||
|
||||
# Clip does crossproduct scoring by default, so we're only
|
||||
# interested in the results where image and text and in the same
|
||||
# batch position.
|
||||
diag = torch.diagonal if self.framework == "pt" else tf.linalg.diag_part
|
||||
logits_per_image = diag(outputs.logits_per_image)
|
||||
outputs = self.model(**text_inputs, **model_inputs)
|
||||
|
||||
model_outputs = {
|
||||
"is_last": is_last,
|
||||
"candidate_label": candidate_label,
|
||||
"logits_per_image": logits_per_image,
|
||||
"candidate_labels": candidate_labels,
|
||||
"logits": outputs.logits_per_image,
|
||||
}
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
|
||||
candidate_labels = model_outputs.pop("candidate_labels")
|
||||
logits = model_outputs["logits"][0]
|
||||
if self.framework == "pt":
|
||||
logits = torch.cat([output["logits_per_image"] for output in model_outputs])
|
||||
probs = logits.softmax(dim=0)
|
||||
probs = logits.softmax(dim=-1).squeeze(-1)
|
||||
scores = probs.tolist()
|
||||
else:
|
||||
logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0)
|
||||
probs = stable_softmax(logits, axis=0)
|
||||
probs = stable_softmax(logits, axis=-1)
|
||||
scores = probs.numpy().tolist()
|
||||
|
||||
result = [
|
||||
|
||||
Reference in New Issue
Block a user