Fix pipeline tests - torch imports (#31227)
* Fix pipeline tests - torch imports * Frameowrk dependant float conversion
This commit is contained in:
@@ -202,7 +202,12 @@ class TextClassificationPipeline(Pipeline):
|
||||
function_to_apply = ClassificationFunction.NONE
|
||||
|
||||
outputs = model_outputs["logits"][0]
|
||||
|
||||
if self.framework == "pt":
|
||||
# To enable using fp16 and bf16
|
||||
outputs = outputs.float().numpy()
|
||||
else:
|
||||
outputs = outputs.numpy()
|
||||
|
||||
if function_to_apply == ClassificationFunction.SIGMOID:
|
||||
scores = sigmoid(outputs)
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
@@ -24,6 +22,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
@@ -36,6 +35,10 @@ from transformers.testing_utils import (
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
# These 2 model types require different inputs than those of the usual text models.
|
||||
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user