device agnostic pipelines testing (#27129)
* device agnostic pipelines testing * pass torch_device
This commit is contained in:
@@ -20,7 +20,7 @@ from transformers import (
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
@@ -96,13 +96,11 @@ class TextClassificationPipelineTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_accepts_torch_device(self):
|
||||
import torch
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="hf-internal-testing/tiny-random-distilbert",
|
||||
framework="pt",
|
||||
device=torch.device("cpu"),
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
|
||||
Reference in New Issue
Block a user