Accepting real pytorch device as arguments. (#17318)
* Accepting real pytorch device as arguments. * is_torch_available.
This commit is contained in:
@@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_torch
|
||||
def test_accepts_torch_device(self):
|
||||
import torch
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="hf-internal-testing/tiny-random-distilbert",
|
||||
framework="pt",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
text_classifier = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user