Regression pipeline device (#22190)
* Fix regression in pipeline when device=-1 is passed * Add regression test
This commit is contained in:
@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
if self.framework == "pt" and device is not None:
|
||||
self.model = self.model.to(device=device)
|
||||
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
|
||||
self.model.to(device)
|
||||
|
||||
if device is None:
|
||||
# `accelerate` device map
|
||||
|
||||
@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
outputs = list(dataset)
|
||||
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
|
||||
|
||||
def test_pipeline_negative_device(self):
|
||||
# To avoid regressing, pipeline used to accept device=-1
|
||||
classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1)
|
||||
|
||||
expected_output = [{"generated_text": ANY(str)}]
|
||||
actual_output = classifier("Test input.")
|
||||
self.assertEqual(expected_output, actual_output)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_default_pipelines_pt(self):
|
||||
|
||||
Reference in New Issue
Block a user