[Tests, GPU, SLOW] fix a bunch of GPU hardcoded tests in Pytorch (#4468)

* fix gpu slow tests in pytorch

* change model to device syntax
This commit is contained in:
Patrick von Platen
2020-05-19 21:35:04 +02:00
committed by GitHub
parent 5856999a9f
commit aa925a52fa
11 changed files with 28 additions and 11 deletions

View File

@@ -770,7 +770,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())
model = model.to(xm.xla_device())
model.to(xm.xla_device())
return model