Make torch xla available on GPU (#29334)

* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
Yitong Huang
2024-03-11 22:07:16 +08:00
committed by GitHub
parent 9a3f4d4daf
commit 873d9bb3cc
25 changed files with 120 additions and 77 deletions

View File

@@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich - `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs - `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs - `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU - `require_torch_xla` - wie `require_torch` plus erfordert TorchXLA
Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen: Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:

View File

@@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs - `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs - `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs - `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU - `require_torch_xla` - as `require_torch` plus requires TorchXLA
Let's depict the GPU requirements in the following table: Let's depict the GPU requirements in the following table:

View File

@@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。 - `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。 - `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。 - `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。 - `require_torch_xla` - `require_torch` に加えて、TorchXLAが必要です。
以下の表にGPUの要件を示します 以下の表にGPUの要件を示します

View File

@@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다. - `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다. - `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다. - `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다. - `require_torch_xla` - `require_torch`에 추가로 TorchXLA가 필요합니다.
GPU 요구 사항을 표로 정리하면 아래와 같습니다: GPU 요구 사항을 표로 정리하면 아래와 같습니다:

View File

@@ -32,7 +32,7 @@ from transformers.optimization import (
) )
from transformers.trainer_pt_utils import get_tpu_sampler from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available from transformers.utils import is_torch_xla_available
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -135,7 +135,7 @@ class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset): if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None return None
elif is_torch_tpu_available(): elif is_torch_xla_available():
return get_tpu_sampler(self.train_dataset) return get_tpu_sampler(self.train_dataset)
else: else:
if self.args.sortish_sampler: if self.args.sortish_sampler:

View File

@@ -46,7 +46,7 @@ from transformers import (
Trainer, Trainer,
TrainingArguments, TrainingArguments,
default_data_collator, default_data_collator,
is_torch_tpu_available, is_torch_xla_available,
set_seed, set_seed,
) )
from transformers.testing_utils import CaptureLogger from transformers.testing_utils import CaptureLogger
@@ -602,9 +602,9 @@ def main():
tokenizer=tokenizer, tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it. # Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator, data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available() if training_args.do_eval and not is_torch_xla_available()
else None, else None,
) )

View File

@@ -45,7 +45,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
is_torch_tpu_available, is_torch_xla_available,
set_seed, set_seed,
) )
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
@@ -620,9 +620,9 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None, eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available() if training_args.do_eval and not is_torch_xla_available()
else None, else None,
) )

View File

@@ -21,7 +21,7 @@ import sys
from time import time from time import time
from unittest.mock import patch from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, require_torch_tpu from transformers.testing_utils import TestCasePlus, require_torch_xla
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
@require_torch_tpu @require_torch_xla
class TorchXLAExamplesTests(TestCasePlus): class TorchXLAExamplesTests(TestCasePlus):
def test_run_glue(self): def test_run_glue(self):
import xla_spawn import xla_spawn

View File

@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
import math import math
import time import time
from transformers import Trainer, is_torch_tpu_available from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met

View File

@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, is_torch_tpu_available from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met

View File

@@ -24,13 +24,13 @@ import quant_trainer
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import Trainer, is_torch_tpu_available from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput from transformers.trainer_utils import PredictionOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met

View File

@@ -1093,6 +1093,7 @@ _import_structure = {
"is_torch_npu_available", "is_torch_npu_available",
"is_torch_tpu_available", "is_torch_tpu_available",
"is_torchvision_available", "is_torchvision_available",
"is_torch_xla_available",
"is_torch_xpu_available", "is_torch_xpu_available",
"is_vision_available", "is_vision_available",
"logging", "logging",
@@ -5897,6 +5898,7 @@ if TYPE_CHECKING:
is_torch_neuroncore_available, is_torch_neuroncore_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_tpu_available, is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
is_torchvision_available, is_torchvision_available,
is_vision_available, is_vision_available,

View File

@@ -20,7 +20,7 @@ from typing import Tuple
from ..utils import ( from ..utils import (
cached_property, cached_property,
is_torch_available, is_torch_available,
is_torch_tpu_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
logging, logging,
requires_backends, requires_backends,
@@ -31,7 +31,7 @@ from .benchmark_args_utils import BenchmarkArguments
if is_torch_available(): if is_torch_available():
import torch import torch
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
@@ -88,7 +88,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
if not self.cuda: if not self.cuda:
device = torch.device("cpu") device = torch.device("cpu")
n_gpu = 0 n_gpu = 0
elif is_torch_tpu_available(): elif is_torch_xla_available():
device = xm.xla_device() device = xm.xla_device()
n_gpu = 0 n_gpu = 0
elif is_torch_xpu_available(): elif is_torch_xpu_available():
@@ -101,7 +101,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
@property @property
def is_tpu(self): def is_tpu(self):
return is_torch_tpu_available() and self.tpu return is_torch_xla_available() and self.tpu
@property @property
def device_idx(self) -> int: def device_idx(self) -> int:

View File

@@ -121,7 +121,7 @@ from .utils import (
is_torch_fx_proxy, is_torch_fx_proxy,
is_torch_mps_available, is_torch_mps_available,
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_xla_available,
is_torchaudio_available, is_torchaudio_available,
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
is_vision_available, is_vision_available,

View File

@@ -72,7 +72,7 @@ if TYPE_CHECKING and _has_neptune:
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402 from ..training_args import ParallelMode # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402 from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
# Integration functions: # Integration functions:
@@ -752,7 +752,7 @@ class WandbCallback(TrainerCallback):
# keep track of model topology and gradients, unsupported on TPU # keep track of model topology and gradients, unsupported on TPU
_watch_model = os.getenv("WANDB_WATCH", "false") _watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"): if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer") self._wandb.run._label(code="transformers_trainer")

View File

@@ -14,11 +14,11 @@
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from ..utils import is_torch_tpu_available from ..utils import is_torch_xla_available
def tpu_spmd_dataloader(dataloader: DataLoader): def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available(): if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.parallel_loader as pl
assert isinstance( assert isinstance(

View File

@@ -84,7 +84,7 @@ from .utils import (
is_remote_url, is_remote_url,
is_safetensors_available, is_safetensors_available,
is_torch_sdpa_available, is_torch_sdpa_available,
is_torch_tpu_available, is_torch_xla_available,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
strtobool, strtobool,
@@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# Adding fix for https://github.com/pytorch/xla/issues/4152 # Adding fix for https://github.com/pytorch/xla/issues/4152
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
# NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
return torch.bfloat16 return torch.bfloat16
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
if t.dtype == torch.float: if t.dtype == torch.float:
return torch.bfloat16 return torch.bfloat16
if t.dtype == torch.double: if t.dtype == torch.double:

View File

@@ -19,7 +19,7 @@ from packaging import version
from safetensors.torch import storage_ptr, storage_size from safetensors.torch import storage_ptr, storage_size
from torch import nn from torch import nn
from .utils import is_torch_tpu_available, logging from .utils import is_torch_xla_available, logging
ALL_LAYERNORM_LAYERS = [nn.LayerNorm] ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id. non-overlapping lifetimes may have the same id.
""" """
if tensor.device.type == "xla" and is_torch_tpu_available(): if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage # NOTE: xla tensors dont have storage
# use some other unique id to distinguish. # use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's # this is a XLA tensor, it must be created using torch_xla's

View File

@@ -115,7 +115,7 @@ from .utils import (
is_torch_sdpa_available, is_torch_sdpa_available,
is_torch_tensorrt_fx_available, is_torch_tensorrt_fx_available,
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdynamo_available, is_torchdynamo_available,
@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
(test_case) (test_case)
def require_torch_tpu(test_case): def require_torch_xla(test_case):
""" """
Decorator marking a test that requires a TPU (in PyTorch). Decorator marking a test that requires TorchXLA (in PyTorch).
""" """
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case) return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
def require_torch_neuroncore(test_case): def require_torch_neuroncore(test_case):

View File

@@ -149,7 +149,7 @@ from .utils import (
is_torch_compile_available, is_torch_compile_available,
is_torch_neuroncore_available, is_torch_neuroncore_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_tpu_available, is_torch_xla_available,
logging, logging,
strtobool, strtobool,
) )
@@ -170,7 +170,7 @@ if is_apex_available():
if is_datasets_available(): if is_datasets_available():
import datasets import datasets
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs import torch_xla.distributed.spmd as xs
@@ -508,7 +508,7 @@ class Trainer:
"Passing a `model_init` is incompatible with providing the `optimizers` argument. " "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
) )
if is_torch_tpu_available() and self.optimizer is not None: if is_torch_xla_available() and self.optimizer is not None:
for param in self.model.parameters(): for param in self.model.parameters():
model_device = param.device model_device = param.device
break break
@@ -856,7 +856,7 @@ class Trainer:
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code # Deprecated code
if self.args.use_legacy_prediction_loop: if self.args.use_legacy_prediction_loop:
if is_torch_tpu_available(): if is_torch_xla_available():
return SequentialDistributedSampler( return SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
) )
@@ -1975,7 +1975,7 @@ class Trainer:
if ( if (
args.logging_nan_inf_filter args.logging_nan_inf_filter
and not is_torch_tpu_available() and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
): ):
# if loss is nan or inf simply add the average of previous logged losses # if loss is nan or inf simply add the average of previous logged losses
@@ -2027,7 +2027,7 @@ class Trainer:
if hasattr(grad_norm, "item"): if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item() grad_norm = grad_norm.item()
else: else:
grad_norm = _grad_norm.item() if _grad_norm is not None else None grad_norm = _grad_norm
# Optimizer step # Optimizer step
self.optimizer.step() self.optimizer.step()
@@ -2050,7 +2050,7 @@ class Trainer:
# PyTorch/XLA relies on the data loader to insert the mark_step for # PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually # each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here. # insert the mark_step here.
if is_torch_tpu_available(): if is_torch_xla_available():
xm.mark_step() xm.mark_step()
break break
if step < 0: if step < 0:
@@ -2065,7 +2065,7 @@ class Trainer:
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug: if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available(): if is_torch_xla_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report()) xm.master_print(met.metrics_report())
else: else:
@@ -2083,7 +2083,7 @@ class Trainer:
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0. # Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available(): if is_torch_xla_available():
xm.rendezvous("load_best_model_at_end") xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED: elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier() dist.barrier()
@@ -2402,7 +2402,7 @@ class Trainer:
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged: if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_tpu_available(): if is_torch_xla_available():
xm.mark_step() xm.mark_step()
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
@@ -2415,7 +2415,7 @@ class Trainer:
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None: if grad_norm is not None:
logs["grad_norm"] = grad_norm logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
logs["learning_rate"] = self._get_learning_rate() logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar self._total_loss_scalar += tr_loss_scalar
@@ -2478,7 +2478,7 @@ class Trainer:
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted." "\nThis won't yield the same results as if the training had not been interrupted."
) )
if is_torch_tpu_available(): if is_torch_xla_available():
xm.set_rng_state(checkpoint_rng_state["xla"]) xm.set_rng_state(checkpoint_rng_state["xla"])
if is_torch_npu_available(): if is_torch_npu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED: if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
@@ -2556,7 +2556,7 @@ class Trainer:
else: else:
rng_states["cuda"] = torch.cuda.random.get_rng_state() rng_states["cuda"] = torch.cuda.random.get_rng_state()
if is_torch_tpu_available(): if is_torch_xla_available():
rng_states["xla"] = xm.get_rng_state() rng_states["xla"] = xm.get_rng_state()
if is_torch_npu_available(): if is_torch_npu_available():
@@ -2575,7 +2575,7 @@ class Trainer:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
def _save_optimizer_and_scheduler(self, output_dir): def _save_optimizer_and_scheduler(self, output_dir):
if is_torch_tpu_available(): if is_torch_xla_available():
xm.rendezvous("saving_optimizer_states") xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
@@ -2620,7 +2620,7 @@ class Trainer:
if ( if (
self.args.should_save self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available() and not is_torch_xla_available()
): ):
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2657,7 +2657,7 @@ class Trainer:
) )
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states # Load in optimizer and scheduler states
if is_torch_tpu_available(): if is_torch_xla_available():
# On TPU we have to take some extra precautions to properly load the states on the right device. # On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
@@ -2964,7 +2964,7 @@ class Trainer:
if output_dir is None: if output_dir is None:
output_dir = self.args.output_dir output_dir = self.args.output_dir
if is_torch_tpu_available(): if is_torch_xla_available():
self._save_tpu(output_dir) self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes. # Calling the state_dict needs to be done on the wrapped model and on all processes.
@@ -3405,7 +3405,7 @@ class Trainer:
main_input_name = getattr(self.model, "main_input_name", "input_ids") main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available(): if is_torch_xla_available():
xm.mark_step() xm.mark_step()
# Update containers on host # Update containers on host
@@ -3529,7 +3529,7 @@ class Trainer:
""" """
if tensors is None: if tensors is None:
return return
if is_torch_tpu_available(): if is_torch_xla_available():
if name is None: if name is None:
name = "nested_gather" name = "nested_gather"
tensors = nested_xla_mesh_reduce(tensors, name) tensors = nested_xla_mesh_reduce(tensors, name)
@@ -4045,7 +4045,7 @@ class Trainer:
""" """
if tensors is None: if tensors is None:
return return
if is_torch_tpu_available(): if is_torch_xla_available():
tensors = nested_xla_mesh_reduce(tensors, name) tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors) tensors = smp_gather(tensors)

View File

@@ -39,13 +39,13 @@ from torch.utils.data.distributed import DistributedSampler
from .integrations.deepspeed import is_deepspeed_zero3_enabled from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .tokenization_utils_base import BatchEncoding from .tokenization_utils_base import BatchEncoding
from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging
if is_training_run_on_sagemaker(): if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout)) logging.add_handler(StreamHandler(sys.stdout))
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
@@ -179,7 +179,7 @@ def nested_detach(tensors):
def nested_xla_mesh_reduce(tensors, name): def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available(): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):

View File

@@ -37,7 +37,7 @@ from .utils import (
is_torch_cuda_available, is_torch_cuda_available,
is_torch_mps_available, is_torch_mps_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_tpu_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
requires_backends, requires_backends,
) )
@@ -340,7 +340,7 @@ def is_main_process(local_rank):
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
`local_rank`. `local_rank`.
""" """
if is_torch_tpu_available(check_device=True): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
return xm.get_ordinal() == 0 return xm.get_ordinal() == 0
@@ -351,7 +351,7 @@ def total_processes_number(local_rank):
""" """
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs. Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
""" """
if is_torch_tpu_available(check_device=True): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
return xm.xrt_world_size() return xm.xrt_world_size()

View File

@@ -49,7 +49,7 @@ from .utils import (
is_torch_neuroncore_available, is_torch_neuroncore_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
logging, logging,
requires_backends, requires_backends,
@@ -74,7 +74,7 @@ if is_accelerate_available():
from .trainer_pt_utils import AcceleratorConfig from .trainer_pt_utils import AcceleratorConfig
if is_torch_tpu_available(check_device=False): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
if is_torch_neuroncore_available(check_device=False): if is_torch_neuroncore_available(check_device=False):
@@ -130,7 +130,9 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]:
""" """
Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device. Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
""" """
if is_torch_tpu_available(): if is_torch_xla_available():
if device.type == "cpu":
return "CPU"
return xm.xla_real_devices([device])[0].split(":")[0] return xm.xla_real_devices([device])[0].split(":")[0]
return None return None
@@ -1475,7 +1477,7 @@ class TrainingArguments:
self.half_precision_backend = self.fp16_backend self.half_precision_backend = self.fp16_backend
if self.bf16 or self.bf16_full_eval: if self.bf16 or self.bf16_full_eval:
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available():
# cpu # cpu
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
elif not self.use_cpu: elif not self.use_cpu:
@@ -1530,7 +1532,7 @@ class TrainingArguments:
and (self.device.type != "cuda") and (self.device.type != "cuda")
and (self.device.type != "npu") and (self.device.type != "npu")
and (self.device.type != "xpu") and (self.device.type != "xpu")
and (get_xla_device_type(self.device) != "GPU") and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (self.fp16 or self.fp16_full_eval) and (self.fp16 or self.fp16_full_eval)
): ):
raise ValueError( raise ValueError(
@@ -1544,7 +1546,7 @@ class TrainingArguments:
and (self.device.type != "cuda") and (self.device.type != "cuda")
and (self.device.type != "npu") and (self.device.type != "npu")
and (self.device.type != "xpu") and (self.device.type != "xpu")
and (get_xla_device_type(self.device) != "GPU") and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (get_xla_device_type(self.device) != "TPU") and (get_xla_device_type(self.device) != "TPU")
and (self.device.type != "cpu") and (self.device.type != "cpu")
and (self.bf16 or self.bf16_full_eval) and (self.bf16 or self.bf16_full_eval)
@@ -1694,7 +1696,8 @@ class TrainingArguments:
if self.fsdp_config["xla"]: if self.fsdp_config["xla"]:
if len(self.fsdp) > 0: if len(self.fsdp) > 0:
# store XLA fsdp configuration parameters into a dictionary # store XLA fsdp configuration parameters into a dictionary
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}) # Copy the config to avoid modifying the original config (which may be used for JSON serialization)
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
# apply appropriate string to torch.dtype conversions for parameters # apply appropriate string to torch.dtype conversions for parameters
if "compute_dtype" in self.xla_fsdp_config: if "compute_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
@@ -1948,7 +1951,7 @@ class TrainingArguments:
"torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. " "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
) )
if is_torch_tpu_available(): if is_torch_xla_available():
device = self.distributed_state.device device = self.distributed_state.device
self._n_gpu = 0 self._n_gpu = 0
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
@@ -2029,7 +2032,7 @@ class TrainingArguments:
- `ParallelMode.TPU`: several TPU cores. - `ParallelMode.TPU`: several TPU cores.
""" """
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
if is_torch_tpu_available(): if is_torch_xla_available():
return ParallelMode.TPU return ParallelMode.TPU
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
return ParallelMode.SAGEMAKER_MODEL_PARALLEL return ParallelMode.SAGEMAKER_MODEL_PARALLEL
@@ -2180,7 +2183,7 @@ class TrainingArguments:
# tell all replicas to wait # tell all replicas to wait
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
if is_torch_tpu_available(): if is_torch_xla_available():
xm.rendezvous(desc) xm.rendezvous(desc)
else: else:
dist.barrier() dist.barrier()
@@ -2189,7 +2192,7 @@ class TrainingArguments:
if is_main_process: if is_main_process:
# the wait is over # the wait is over
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
if is_torch_tpu_available(): if is_torch_xla_available():
xm.rendezvous(desc) xm.rendezvous(desc)
else: else:
dist.barrier() dist.barrier()

View File

@@ -189,6 +189,7 @@ from .import_utils import (
is_torch_tensorrt_fx_available, is_torch_tensorrt_fx_available,
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available, is_torch_xpu_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdistx_available, is_torchdistx_available,

View File

@@ -62,6 +62,9 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
# Set USE_TORCH_XLA=0 to run a native PyTorch job in an environment that has torch_xla installed.
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it. # `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
@@ -249,6 +252,13 @@ if _torch_available:
) )
# torch_xla is treated as available only when USE_TORCH_XLA opts in AND the package is found.
# Both names are bound up front so that neither is left undefined when USE_TORCH_XLA disables
# the detection below.
# NOTE(review): "N/A" is a placeholder for "version unknown" — confirm it matches the value
# `_is_package_available` reports for a missing package.
_torch_xla_available = False
_torch_xla_version = "N/A"
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
    if _torch_xla_available:
        logger.info(f"Torch XLA version {_torch_xla_version} available.")
def is_kenlm_available(): def is_kenlm_available():
return _kenlm_available return _kenlm_available
@@ -484,6 +494,12 @@ def is_g2p_en_available():
@lru_cache() @lru_cache()
def is_torch_tpu_available(check_device=True): def is_torch_tpu_available(check_device=True):
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment" "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
warnings.warn(
"`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
"Please use the `is_torch_xla_available` instead.",
FutureWarning,
)
if not _torch_available: if not _torch_available:
return False return False
if importlib.util.find_spec("torch_xla") is not None: if importlib.util.find_spec("torch_xla") is not None:
@@ -500,10 +516,31 @@ def is_torch_tpu_available(check_device=True):
return False return False
@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
    the USE_TORCH_XLA to false.

    The result is memoized per argument combination by `lru_cache`.

    Args:
        check_is_tpu (`bool`, *optional*, defaults to `False`):
            If `True`, additionally require the XLA runtime device type to be `"TPU"`.
        check_is_gpu (`bool`, *optional*, defaults to `False`):
            If `True`, additionally require the XLA runtime device type to be `"GPU"` or `"CUDA"`.

    Returns:
        `bool`: `False` when the module-level `_torch_xla_available` flag is falsy; otherwise the outcome of the
        requested device-type check, or `True` when no device-type check was requested.

    Raises:
        AssertionError: If both `check_is_tpu` and `check_is_gpu` are `True`.
    """
    # Raised explicitly instead of via `assert` so the validation still runs under `python -O`;
    # the exception type is kept as AssertionError for backward compatibility.
    if check_is_tpu and check_is_gpu:
        raise AssertionError("The check_is_tpu and check_is_gpu cannot both be true.")

    if not _torch_xla_available:
        return False

    # Imported only after the availability flag has been checked, so environments without
    # (or opting out of) torch_xla never attempt the import.
    import torch_xla

    if check_is_gpu:
        return torch_xla.runtime.device_type() in ("GPU", "CUDA")
    if check_is_tpu:
        return torch_xla.runtime.device_type() == "TPU"

    return True
@lru_cache() @lru_cache()
def is_torch_neuroncore_available(check_device=True): def is_torch_neuroncore_available(check_device=True):
if importlib.util.find_spec("torch_neuronx") is not None: if importlib.util.find_spec("torch_neuronx") is not None:
return is_torch_tpu_available(check_device) return is_torch_xla_available()
return False return False