Fix all is_torch_tpu_available issues (#17936)
* Fix all is_torch_tpu_available
This commit is contained in:
@@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available
|
|||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.trainer_utils import PredictionOutput
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
|||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.trainer_utils import PredictionOutput
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers.trainer_utils import PredictionOutput
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from .benchmark_args_utils import BenchmarkArguments
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -467,7 +467,7 @@ def require_torch_tpu(test_case):
|
|||||||
"""
|
"""
|
||||||
Decorator marking a test that requires a TPU (in PyTorch).
|
Decorator marking a test that requires a TPU (in PyTorch).
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
|
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ if version.parse(torch.__version__) >= version.parse("1.10"):
|
|||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_
|
|||||||
if is_training_run_on_sagemaker():
|
if is_training_run_on_sagemaker():
|
||||||
logging.add_handler(StreamHandler(sys.stdout))
|
logging.add_handler(StreamHandler(sys.stdout))
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
|
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
|
||||||
|
|||||||
@@ -307,7 +307,7 @@ def is_main_process(local_rank):
|
|||||||
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
||||||
`local_rank`.
|
`local_rank`.
|
||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=True):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
return xm.get_ordinal() == 0
|
return xm.get_ordinal() == 0
|
||||||
@@ -318,7 +318,7 @@ def total_processes_number(local_rank):
|
|||||||
"""
|
"""
|
||||||
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
|
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
|
||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=True):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
return xm.xrt_world_size()
|
return xm.xrt_world_size()
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -396,19 +396,22 @@ def is_ftfy_available():
|
|||||||
return _ftfy_available
|
return _ftfy_available
|
||||||
|
|
||||||
|
|
||||||
def is_torch_tpu_available():
|
def is_torch_tpu_available(check_device=True):
|
||||||
|
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
||||||
if not _torch_available:
|
if not _torch_available:
|
||||||
return False
|
return False
|
||||||
if importlib.util.find_spec("torch_xla") is None:
|
if importlib.util.find_spec("torch_xla") is not None:
|
||||||
return False
|
if check_device:
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
|
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
|
||||||
try:
|
try:
|
||||||
xm.xla_device()
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
_ = xm.xla_device()
|
||||||
return True
|
return True
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return False
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_torchdynamo_available():
|
def is_torchdynamo_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user