Regression pipeline device (#22190)

* Fix regression in pipeline when device=-1 is passed

* Add regression test
This commit is contained in:
Sylvain Gugger
2023-03-15 14:13:38 -04:00
committed by GitHub
parent 737681477c
commit 42ad693b7b
2 changed files with 10 additions and 2 deletions

View File

@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard
self.framework = framework
if self.framework == "pt" and device is not None:
self.model = self.model.to(device=device)
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
# `accelerate` device map