From 2cb2ea3fa1a6819893f29f44737c0f83899a9e57 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 18 May 2022 16:06:24 +0200 Subject: [PATCH] Accepting real pytorch device as arguments. (#17318) * Accepting real pytorch device as arguments. * is_torch_available. --- src/transformers/pipelines/base.py | 7 +++++-- .../test_pipelines_text_classification.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index a33089547f..4712eaba57 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -693,7 +693,7 @@ PIPELINE_INIT_ARGS = r""" Reference to the object in charge of parsing supplied pipeline parameters. device (`int`, *optional*, defaults to -1): Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on - the associated CUDA device id. + the associated CUDA device id. You can pass native `torch.device` too. binary_output (`bool`, *optional*, defaults to `False`): Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text. """ @@ -750,7 +750,10 @@ class Pipeline(_ScikitCompat): self.feature_extractor = feature_extractor self.modelcard = modelcard self.framework = framework - self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") + if is_torch_available() and isinstance(device, torch.device): + self.device = device + else: + self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") self.binary_output = binary_output # Special handling diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 39deed9bee..4cb72be4c2 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC outputs = text_classifier("This is great !") self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_torch + def test_accepts_torch_device(self): + import torch + + text_classifier = pipeline( + task="text-classification", + model="hf-internal-testing/tiny-random-distilbert", + framework="pt", + device=torch.device("cpu"), + ) + + outputs = text_classifier("This is great !") + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_tf def test_small_model_tf(self): text_classifier = pipeline(