device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
This commit is contained in:
Hz, Ji
2023-10-31 22:46:31 +08:00
committed by GitHub
parent 08fadc8085
commit f53041a753
10 changed files with 64 additions and 58 deletions

View File

@@ -30,8 +30,9 @@ from transformers.testing_utils import (
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
from .test_pipelines_common import ANY
@@ -391,13 +392,13 @@ class TokenClassificationPipelineTests(unittest.TestCase):
],
)
@require_torch_gpu
@require_torch_accelerator
@slow
def test_gpu(self):
def test_accelerator(self):
sentence = "This is dummy sentence"
ner = pipeline(
"token-classification",
device=0,
device=torch_device,
aggregation_strategy=AggregationStrategy.SIMPLE,
)