check for tpu availability in save_pretrained (#7699)

Added is_torch_tpu_available() to the condition
for saving a model as xla model. "xla_device"
property of config can also be True on a non-xla
device, when loading a checkpointthat was trained
on xla before.

Resolves #7695
This commit is contained in:
fteufel
2020-10-12 10:10:17 +02:00
committed by GitHub
parent 13c1857718
commit 2f34bcf3e7

View File

@@ -716,7 +716,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
if getattr(self.config, "xla_device", False):
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if xm.is_master_ordinal():