Fix pipeline tests - torch imports (#31227)
* Fix pipeline tests - torch imports * Frameowrk dependant float conversion
This commit is contained in:
@@ -202,7 +202,12 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
function_to_apply = ClassificationFunction.NONE
|
function_to_apply = ClassificationFunction.NONE
|
||||||
|
|
||||||
outputs = model_outputs["logits"][0]
|
outputs = model_outputs["logits"][0]
|
||||||
outputs = outputs.float().numpy()
|
|
||||||
|
if self.framework == "pt":
|
||||||
|
# To enable using fp16 and bf16
|
||||||
|
outputs = outputs.float().numpy()
|
||||||
|
else:
|
||||||
|
outputs = outputs.numpy()
|
||||||
|
|
||||||
if function_to_apply == ClassificationFunction.SIGMOID:
|
if function_to_apply == ClassificationFunction.SIGMOID:
|
||||||
scores = sigmoid(outputs)
|
scores = sigmoid(outputs)
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
@@ -24,6 +22,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -36,6 +35,10 @@ from transformers.testing_utils import (
|
|||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# These 2 model types require different inputs than those of the usual text models.
|
# These 2 model types require different inputs than those of the usual text models.
|
||||||
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user