Regression pipeline device (#22190)
* Fix regression in pipeline when device=-1 is passed * Add regression test
This commit is contained in:
@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
if self.framework == "pt" and device is not None:
|
||||
self.model = self.model.to(device=device)
|
||||
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
|
||||
self.model.to(device)
|
||||
|
||||
if device is None:
|
||||
# `accelerate` device map
|
||||
|
||||
Reference in New Issue
Block a user