Should check that torch TPU is available (#5636)
This commit is contained in:
@@ -34,6 +34,7 @@ from .file_utils import (
|
|||||||
cached_path,
|
cached_path,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
is_torch_tpu_available,
|
||||||
)
|
)
|
||||||
from .generation_utils import GenerationMixin
|
from .generation_utils import GenerationMixin
|
||||||
|
|
||||||
@@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
}
|
}
|
||||||
return model, loading_info
|
return model, loading_info
|
||||||
|
|
||||||
if hasattr(config, "xla_device") and config.xla_device:
|
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
||||||
|
|||||||
Reference in New Issue
Block a user