From b2a41d2be478c12b6222c99b59adfc8663f2adc6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 2 Mar 2023 19:46:22 +0100 Subject: [PATCH] Faster zero shot image (#21897) * Make ZeroShotImageClassificationPipeline faster The pipeline makes separate calls to model for each candidate label. This commit combines all labels into one call. Original code takes more that 60 seconds to process one image and 1000 candidate labels. Updated code takes less than 2 seconds. * implement batching * code formatting * Creating an even faster zero-shot-image-classifiction. Unfortunately super tailored towards CLIP. Co-Authored-By: Yessen Kanapin * Quality. * Cleanup. * Order different on the CI it seems. * Cleanup. * Quality. --------- Co-authored-by: Yessen Kanapin --- .../zero_shot_image_classification.py | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index 78ff8b7a8c..f19a548c85 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -1,3 +1,4 @@ +from collections import UserDict from typing import List, Union from ..utils import ( @@ -8,7 +9,7 @@ from ..utils import ( logging, requires_backends, ) -from .base import PIPELINE_INIT_ARGS, ChunkPipeline +from .base import PIPELINE_INIT_ARGS, Pipeline if is_vision_available(): @@ -17,18 +18,16 @@ if is_vision_available(): from ..image_utils import load_image if is_torch_available(): - import torch + pass if is_tf_available(): - import tensorflow as tf - from ..tf_utils import stable_softmax logger = logging.get_logger(__name__) @add_end_docstrings(PIPELINE_INIT_ARGS) -class ZeroShotImageClassificationPipeline(ChunkPipeline): +class ZeroShotImageClassificationPipeline(Pipeline): """ Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you provide an image and a set of `candidate_labels`. @@ -107,42 +106,39 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline): return preprocess_params, {}, {} def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."): - n = len(candidate_labels) - for i, candidate_label in enumerate(candidate_labels): - image = load_image(image) - images = self.image_processor(images=[image], return_tensors=self.framework) - sequence = hypothesis_template.format(candidate_label) - inputs = self.tokenizer(sequence, return_tensors=self.framework) - inputs["pixel_values"] = images.pixel_values - yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs} + image = load_image(image) + inputs = self.image_processor(images=[image], return_tensors=self.framework) + inputs["candidate_labels"] = candidate_labels + sequences = [hypothesis_template.format(x) for x in candidate_labels] + text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True) + inputs["text_inputs"] = [text_inputs] + return inputs def _forward(self, model_inputs): - is_last = model_inputs.pop("is_last") - candidate_label = model_inputs.pop("candidate_label") - outputs = self.model(**model_inputs) + candidate_labels = model_inputs.pop("candidate_labels") + text_inputs = model_inputs.pop("text_inputs") + if isinstance(text_inputs[0], UserDict): + text_inputs = text_inputs[0] + else: + # Batching case. + text_inputs = text_inputs[0][0] - # Clip does crossproduct scoring by default, so we're only - # interested in the results where image and text and in the same - # batch position. - diag = torch.diagonal if self.framework == "pt" else tf.linalg.diag_part - logits_per_image = diag(outputs.logits_per_image) + outputs = self.model(**text_inputs, **model_inputs) model_outputs = { - "is_last": is_last, - "candidate_label": candidate_label, - "logits_per_image": logits_per_image, + "candidate_labels": candidate_labels, + "logits": outputs.logits_per_image, } return model_outputs def postprocess(self, model_outputs): - candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] + candidate_labels = model_outputs.pop("candidate_labels") + logits = model_outputs["logits"][0] if self.framework == "pt": - logits = torch.cat([output["logits_per_image"] for output in model_outputs]) - probs = logits.softmax(dim=0) + probs = logits.softmax(dim=-1).squeeze(-1) scores = probs.tolist() else: - logits = tf.concat([output["logits_per_image"] for output in model_outputs], axis=0) - probs = stable_softmax(logits, axis=0) + probs = stable_softmax(logits, axis=-1) scores = probs.numpy().tolist() result = [