From 7c4c6f60849fa8c8fa5904e2c6a6e1a21c981f56 Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Wed, 29 Jun 2022 11:03:33 -0400 Subject: [PATCH] Fix all is_torch_tpu_available issues (#17936) * Fix all is_torch_tpu_available --- .../pytorch/question-answering/trainer_qa.py | 2 +- .../question-answering/trainer_seq2seq_qa.py | 2 +- .../quantization-qdqbert/trainer_quant_qa.py | 2 +- src/transformers/benchmark/benchmark_args.py | 2 +- src/transformers/testing_utils.py | 2 +- src/transformers/trainer.py | 2 +- src/transformers/trainer_pt_utils.py | 2 +- src/transformers/trainer_utils.py | 4 ++-- src/transformers/training_args.py | 2 +- src/transformers/utils/import_utils.py | 21 +++++++++++-------- 10 files changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py index 7f98eba236..59d7a084c1 100644 --- a/examples/pytorch/question-answering/trainer_qa.py +++ b/examples/pytorch/question-answering/trainer_qa.py @@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available from transformers.trainer_utils import PredictionOutput -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index 6a5f6da941..6ad66aeec5 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available from transformers.trainer_utils import PredictionOutput -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py index b23edb6d51..ef0d93a7e3 100644 --- a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py +++ b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py @@ -30,7 +30,7 @@ from transformers.trainer_utils import PredictionOutput logger = logging.getLogger(__name__) -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py index 57af2481ef..2d759ac342 100644 --- a/src/transformers/benchmark/benchmark_args.py +++ b/src/transformers/benchmark/benchmark_args.py @@ -24,7 +24,7 @@ from .benchmark_args_utils import BenchmarkArguments if is_torch_available(): import torch -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 569b117667..1a71e9d840 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -467,7 +467,7 @@ def require_torch_tpu(test_case): """ Decorator marking a test that requires a TPU (in PyTorch). """ - return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case) + return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case) if is_torch_available(): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5b0945703c..0a57a769db 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -171,7 +171,7 @@ if version.parse(torch.__version__) >= version.parse("1.10"): if is_datasets_available(): import datasets -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index f669e6f32a..8cb00f3ad1 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -43,7 +43,7 @@ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_ if is_training_run_on_sagemaker(): logging.add_handler(StreamHandler(sys.stdout)) -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0 diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 46fd0cdd05..daa745c365 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -307,7 +307,7 @@ def is_main_process(local_rank): Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on `local_rank`. """ - if is_torch_tpu_available(): + if is_torch_tpu_available(check_device=True): import torch_xla.core.xla_model as xm return xm.get_ordinal() == 0 @@ -318,7 +318,7 @@ def total_processes_number(local_rank): """ Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs. """ - if is_torch_tpu_available(): + if is_torch_tpu_available(check_device=True): import torch_xla.core.xla_model as xm return xm.xrt_world_size() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 36448cfd54..5c370cf072 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -52,7 +52,7 @@ if is_torch_available(): import torch import torch.distributed as dist -if is_torch_tpu_available(): +if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 270f6c06c0..275477a4ba 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -396,19 +396,22 @@ def is_ftfy_available(): return _ftfy_available -def is_torch_tpu_available(): +def is_torch_tpu_available(check_device=True): + "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" if not _torch_available: return False - if importlib.util.find_spec("torch_xla") is None: - return False - import torch_xla.core.xla_model as xm + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm - # We need to check if `xla_device` can be found, will raise a RuntimeError if not - try: - xm.xla_device() + _ = xm.xla_device() + return True + except RuntimeError: + return False return True - except RuntimeError: - return False + return False def is_torchdynamo_available():