From a958c4a801a0db655981678b050eaa855289405c Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Tue, 25 Jun 2024 18:14:39 +0800
Subject: [PATCH] fix output data type of image classification (#31444)

* fix output data type of image classification

* add tests for low-precision pipeline

* add bf16 pipeline tests

* fix bf16 tests

* Update tests/pipelines/test_pipelines_image_classification.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix import

* fix import torch

* fix style

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .../pipelines/image_classification.py         |  7 ++++-
 .../test_pipelines_image_classification.py    | 28 +++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py
index 62793c252a..bfa005f06b 100644
--- a/src/transformers/pipelines/image_classification.py
+++ b/src/transformers/pipelines/image_classification.py
@@ -23,6 +23,8 @@ if is_tf_available():
     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 logger = logging.get_logger(__name__)
@@ -180,7 +182,10 @@ class ImageClassificationPipeline(Pipeline):
             top_k = self.model.config.num_labels
 
         outputs = model_outputs["logits"][0]
-        outputs = outputs.numpy()
+        if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
+            outputs = outputs.to(torch.float32).numpy()
+        else:
+            outputs = outputs.numpy()
 
         if function_to_apply == ClassificationFunction.SIGMOID:
             scores = sigmoid(outputs)
diff --git a/tests/pipelines/test_pipelines_image_classification.py b/tests/pipelines/test_pipelines_image_classification.py
index 9f6a8adfd1..3e93f31d18 100644
--- a/tests/pipelines/test_pipelines_image_classification.py
+++ b/tests/pipelines/test_pipelines_image_classification.py
@@ -18,6 +18,7 @@ from transformers import (
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     PreTrainedTokenizerBase,
+    is_torch_available,
     is_vision_available,
 )
 from transformers.pipelines import ImageClassificationPipeline, pipeline
@@ -34,6 +35,9 @@ from transformers.testing_utils import (
 from .test_pipelines_common import ANY
 
 
+if is_torch_available():
+    import torch
+
 if is_vision_available():
     from PIL import Image
 else:
@@ -177,6 +181,30 @@ class ImageClassificationPipelineTests(unittest.TestCase):
 
         self.assertIs(image_classifier.tokenizer, tokenizer)
 
+    @require_torch
+    def test_torch_float16_pipeline(self):
+        image_classifier = pipeline(
+            "image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.float16
+        )
+        outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=3),
+            [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
+        )
+
+    @require_torch
+    def test_torch_bfloat16_pipeline(self):
+        image_classifier = pipeline(
+            "image-classification", model="hf-internal-testing/tiny-random-vit", torch_dtype=torch.bfloat16
+        )
+        outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=3),
+            [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
+        )
+
     @slow
     @require_torch
     def test_perceiver(self):