From 6b22a8f2d8ded6b2a681e1c078b0e1abf27d045c Mon Sep 17 00:00:00 2001 From: Chujie Zheng Date: Tue, 4 Jun 2024 03:20:48 -0700 Subject: [PATCH] fix bf16 issue in text classification pipeline (#30996) * fix logits dtype * Add bf16/fp16 tests for text_classification pipeline * Update test_pipelines_text_classification.py * fix * fix --- .../pipelines/text_classification.py | 2 +- .../test_pipelines_text_classification.py | 39 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index 6521da098d..bc763c1614 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -202,7 +202,7 @@ class TextClassificationPipeline(Pipeline): function_to_apply = ClassificationFunction.NONE outputs = model_outputs["logits"][0] - outputs = outputs.numpy() + outputs = outputs.float().numpy() if function_to_apply == ClassificationFunction.SIGMOID: scores = sigmoid(outputs) diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 7a33a41c06..6e40f33fbb 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -14,13 +14,24 @@ import unittest +import torch + from transformers import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TextClassificationPipeline, pipeline, ) -from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_tf, + require_torch, + require_torch_bf16, + require_torch_fp16, + slow, + torch_device, +) from .test_pipelines_common import ANY @@ -106,6 +117,32 @@ class TextClassificationPipelineTests(unittest.TestCase): outputs = text_classifier("This is great !") self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_torch_fp16 + def test_accepts_torch_fp16(self): + text_classifier = pipeline( + task="text-classification", + model="hf-internal-testing/tiny-random-distilbert", + framework="pt", + device=torch_device, + torch_dtype=torch.float16, + ) + + outputs = text_classifier("This is great !") + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + + @require_torch_bf16 + def test_accepts_torch_bf16(self): + text_classifier = pipeline( + task="text-classification", + model="hf-internal-testing/tiny-random-distilbert", + framework="pt", + device=torch_device, + torch_dtype=torch.bfloat16, + ) + + outputs = text_classifier("This is great !") + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_tf def test_small_model_tf(self): text_classifier = pipeline(