Fix pipeline tests - torch imports (#31227)

* Fix pipeline tests - torch imports

* Frameowrk dependant float conversion
This commit is contained in:
amyeroberts
2024-06-04 12:30:23 +01:00
committed by GitHub
parent 6b22a8f2d8
commit 4ba66fdb4c
2 changed files with 11 additions and 3 deletions

View File

@@ -14,8 +14,6 @@
import unittest
import torch
from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -24,6 +22,7 @@ from transformers import (
)
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
@@ -36,6 +35,10 @@ from transformers.testing_utils import (
from .test_pipelines_common import ANY
if is_torch_available():
import torch
# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}