[Tests, GPU, SLOW] fix a bunch of GPU hardcoded tests in Pytorch (#4468)
* fix gpu slow tests in pytorch * change model to device syntax
This commit is contained in:
committed by
GitHub
parent
5856999a9f
commit
aa925a52fa
@@ -770,7 +770,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
||||
model = model.to(xm.xla_device())
|
||||
model.to(xm.xla_device())
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user