From 4ba66fdb4c26cc2138c9966193d17944e36af8f6 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 4 Jun 2024 12:30:23 +0100 Subject: [PATCH] Fix pipeline tests - torch imports (#31227) * Fix pipeline tests - torch imports * Frameowrk dependant float conversion --- src/transformers/pipelines/text_classification.py | 7 ++++++- tests/pipelines/test_pipelines_text_classification.py | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index bc763c1614..21ca70c2ac 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -202,7 +202,12 @@ class TextClassificationPipeline(Pipeline): function_to_apply = ClassificationFunction.NONE outputs = model_outputs["logits"][0] - outputs = outputs.float().numpy() + + if self.framework == "pt": + # To enable using fp16 and bf16 + outputs = outputs.float().numpy() + else: + outputs = outputs.numpy() if function_to_apply == ClassificationFunction.SIGMOID: scores = sigmoid(outputs) diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 6e40f33fbb..63adfc45a0 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -14,8 +14,6 @@ import unittest -import torch - from transformers import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, @@ -24,6 +22,7 @@ from transformers import ( ) from transformers.testing_utils import ( is_pipeline_test, + is_torch_available, nested_simplify, require_tf, require_torch, @@ -36,6 +35,10 @@ from transformers.testing_utils import ( from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + # These 2 model types require different inputs than those of the usual text models. _TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}