From b25f7802de2749a5f8c3430437eceabf9e6384b8 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 9 Jul 2020 13:54:32 -0400 Subject: [PATCH] Should check that torch TPU is available (#5636) --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5b000c8125..c532b94b43 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -34,6 +34,7 @@ from .file_utils import ( cached_path, hf_bucket_url, is_remote_url, + is_torch_tpu_available, ) from .generation_utils import GenerationMixin @@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): } return model, loading_info - if hasattr(config, "xla_device") and config.xla_device: + if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available(): import torch_xla.core.xla_model as xm model = xm.send_cpu_data_to_device(model, xm.xla_device())