check for tpu availability in save_pretrained (#7699)
Added is_torch_tpu_available() to the condition for saving a model as xla model. "xla_device" property of config can also be True on a non-xla device, when loading a checkpointthat was trained on xla before. Resolves #7695
This commit is contained in:
@@ -716,7 +716,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||||
|
|
||||||
if getattr(self.config, "xla_device", False):
|
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
if xm.is_master_ordinal():
|
if xm.is_master_ordinal():
|
||||||
|
|||||||
Reference in New Issue
Block a user