Regression pipeline device (#22190)
* Fix regression in pipeline when device=-1 is passed * Add regression test
This commit is contained in:
@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
|
|||||||
self.modelcard = modelcard
|
self.modelcard = modelcard
|
||||||
self.framework = framework
|
self.framework = framework
|
||||||
|
|
||||||
if self.framework == "pt" and device is not None:
|
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
|
||||||
self.model = self.model.to(device=device)
|
self.model.to(device)
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
# `accelerate` device map
|
# `accelerate` device map
|
||||||
|
|||||||
@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
outputs = list(dataset)
|
outputs = list(dataset)
|
||||||
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
|
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
|
||||||
|
|
||||||
|
def test_pipeline_negative_device(self):
|
||||||
|
# To avoid regressing, pipeline used to accept device=-1
|
||||||
|
classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1)
|
||||||
|
|
||||||
|
expected_output = [{"generated_text": ANY(str)}]
|
||||||
|
actual_output = classifier("Test input.")
|
||||||
|
self.assertEqual(expected_output, actual_output)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_load_default_pipelines_pt(self):
|
def test_load_default_pipelines_pt(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user