From 42ad693b7b3f12421834861f84916f6073643e41 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 15 Mar 2023 14:13:38 -0400 Subject: [PATCH] Regression pipeline device (#22190) * Fix regression in pipeline when device=-1 is passed * Add regression test --- src/transformers/pipelines/base.py | 4 ++-- tests/pipelines/test_pipelines_common.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 3dca2d33d1..cdb597ef96 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat): self.modelcard = modelcard self.framework = framework - if self.framework == "pt" and device is not None: - self.model = self.model.to(device=device) + if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0): + self.model.to(device) if device is None: # `accelerate` device map diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f43f439ac2..1a3ddace28 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase): outputs = list(dataset) self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]]) + def test_pipeline_negative_device(self): + # To avoid regressing, pipeline used to accept device=-1 + classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1) + + expected_output = [{"generated_text": ANY(str)}] + actual_output = classifier("Test input.") + self.assertEqual(expected_output, actual_output) + @slow @require_torch def test_load_default_pipelines_pt(self):