Accepting real pytorch device as arguments. (#17318)

* Accepting real pytorch device as arguments.

* is_torch_available.
This commit is contained in:
Nicolas Patry
2022-05-18 16:06:24 +02:00
committed by GitHub
parent 1c9d1f4ca8
commit 2cb2ea3fa1
2 changed files with 19 additions and 2 deletions

View File

@@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch
def test_accepts_torch_device(self):
import torch
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch.device("cpu"),
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(