check for tpu availability in save_pretrained (#7699)
Added is_torch_tpu_available() to the condition for saving a model as xla model. "xla_device" property of config can also be True on a non-xla device, when loading a checkpointthat was trained on xla before. Resolves #7695
This commit is contained in:
@@ -716,7 +716,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
|
||||
if getattr(self.config, "xla_device", False):
|
||||
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
|
||||
Reference in New Issue
Block a user