From 5a06118b3922635699d72d72f9025e71cf04bfba Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Jan 2022 14:16:00 +0100 Subject: [PATCH] Enabling `TF` on `image-classification` pipeline. (#15030) --- .../pipelines/image_classification.py | 37 +++++++++++---- tests/test_pipelines_image_classification.py | 46 +++++++++++++++++-- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index 590a823911..2afb084b91 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -1,6 +1,12 @@ from typing import List, Union -from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends +from ..file_utils import ( + add_end_docstrings, + is_tf_available, + is_torch_available, + is_vision_available, + requires_backends, +) from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -10,6 +16,11 @@ if is_vision_available(): from ..image_utils import load_image +if is_tf_available(): + import tensorflow as tf + + from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING @@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - if self.framework == "tf": - raise ValueError(f"The {self.__class__} is only available in PyTorch.") - requires_backends(self, "vision") - self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) + self.check_model_type( + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + if self.framework == "tf" + else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + ) def _sanitize_parameters(self, top_k=None): postprocess_params = {} @@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline): def preprocess(self, image): image = load_image(image) - model_inputs = self.feature_extractor(images=image, return_tensors="pt") + model_inputs = self.feature_extractor(images=image, return_tensors=self.framework) return model_inputs def _forward(self, model_inputs): @@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline): def postprocess(self, model_outputs, top_k=5): if top_k > self.model.config.num_labels: top_k = self.model.config.num_labels - probs = model_outputs.logits.softmax(-1)[0] - scores, ids = probs.topk(top_k) + + if self.framework == "pt": + probs = model_outputs.logits.softmax(-1)[0] + scores, ids = probs.topk(top_k) + elif self.framework == "tf": + probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0] + topk = tf.math.top_k(probs, k=top_k) + scores, ids = topk.values.numpy(), topk.indices.numpy() + else: + raise ValueError(f"Unsupported framework: {self.framework}") scores = scores.tolist() ids = ids.tolist() diff --git a/tests/test_pipelines_image_classification.py b/tests/test_pipelines_image_classification.py index 36c70de3e1..24062b705a 100644 --- a/tests/test_pipelines_image_classification.py +++ b/tests/test_pipelines_image_classification.py @@ -14,7 +14,12 @@ import unittest -from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available +from transformers import ( + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + PreTrainedTokenizer, + is_vision_available, +) from transformers.pipelines import ImageClassificationPipeline, pipeline from transformers.testing_utils import ( is_pipeline_test, @@ -40,9 +45,9 @@ else: @is_pipeline_test @require_vision -@require_torch class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING def get_test_pipeline(self, model, tokenizer, feature_extractor): image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2) @@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ) @require_tf - @unittest.skip("Image classification is not implemented for TF") def test_small_model_tf(self): - pass + small_model = "lysandre/tiny-vit-random" + image_classifier = pipeline("image-classification", model=small_model) + + outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg") + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"}, + {"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"}, + {"score": 0.0014, "label": "trench coat"}, + {"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"}, + {"score": 0.0014, "label": "baboon"}, + ], + ) + + outputs = image_classifier( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ], + top_k=2, + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"}, + {"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"}, + ], + [ + {"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"}, + {"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"}, + ], + ], + ) def test_custom_tokenizer(self): tokenizer = PreTrainedTokenizer()