Fix pipeline tests - torch imports (#31227)

* Fix pipeline tests - torch imports

* Frameowrk dependant float conversion
This commit is contained in:
amyeroberts
2024-06-04 12:30:23 +01:00
committed by GitHub
parent 6b22a8f2d8
commit 4ba66fdb4c
2 changed files with 11 additions and 3 deletions

View File

@@ -202,7 +202,12 @@ class TextClassificationPipeline(Pipeline):
function_to_apply = ClassificationFunction.NONE
outputs = model_outputs["logits"][0]
if self.framework == "pt":
# To enable using fp16 and bf16
outputs = outputs.float().numpy()
else:
outputs = outputs.numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)

View File

@@ -14,8 +14,6 @@
import unittest
import torch
from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -24,6 +22,7 @@ from transformers import (
)
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
@@ -36,6 +35,10 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY
if is_torch_available():
import torch
# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}