Make torch xla available on GPU (#29334)
* add USE_TORCH_XLA env * rename torch_tpu to torch_xla * better is_torch_xla_available; fix some fsdp and performance issues * fix format * fix bug when pjrt_device is cpu * fix bug * fix the deprecation handling --------- Co-authored-by: anw90 <ang868@gmail.com> Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
@@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
|
|||||||
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
|
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
|
||||||
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
|
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
|
||||||
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
|
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
|
||||||
- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU
|
- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU
|
||||||
|
|
||||||
Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:
|
Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:
|
||||||
|
|
||||||
|
|||||||
@@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
|
|||||||
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
|
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
|
||||||
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
|
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
|
||||||
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
|
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
|
||||||
- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU
|
- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU
|
||||||
|
|
||||||
Let's depict the GPU requirements in the following table:
|
Let's depict the GPU requirements in the following table:
|
||||||
|
|
||||||
|
|||||||
@@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
|
|||||||
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
|
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
|
||||||
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
|
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
|
||||||
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
|
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
|
||||||
- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
|
- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
|
||||||
|
|
||||||
以下の表にGPUの要件を示します:
|
以下の表にGPUの要件を示します:
|
||||||
|
|
||||||
|
|||||||
@@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
|
|||||||
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
|
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
|
||||||
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
|
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
|
||||||
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
|
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
|
||||||
- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
|
- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
|
||||||
|
|
||||||
GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ:
|
GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ:
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from transformers.optimization import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_pt_utils import get_tpu_sampler
|
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||||
from transformers.training_args import ParallelMode
|
from transformers.training_args import ParallelMode
|
||||||
from transformers.utils import is_torch_tpu_available
|
from transformers.utils import is_torch_xla_available
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -135,7 +135,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
return None
|
return None
|
||||||
elif is_torch_tpu_available():
|
elif is_torch_xla_available():
|
||||||
return get_tpu_sampler(self.train_dataset)
|
return get_tpu_sampler(self.train_dataset)
|
||||||
else:
|
else:
|
||||||
if self.args.sortish_sampler:
|
if self.args.sortish_sampler:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from transformers import (
|
|||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import CaptureLogger
|
from transformers.testing_utils import CaptureLogger
|
||||||
@@ -602,9 +602,9 @@ def main():
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||||
data_collator=default_data_collator,
|
data_collator=default_data_collator,
|
||||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
|
||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||||
if training_args.do_eval and not is_torch_tpu_available()
|
if training_args.do_eval and not is_torch_xla_available()
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
@@ -620,9 +620,9 @@ def main():
|
|||||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
|
||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||||
if training_args.do_eval and not is_torch_tpu_available()
|
if training_args.do_eval and not is_torch_xla_available()
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import sys
|
|||||||
from time import time
|
from time import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers.testing_utils import TestCasePlus, require_torch_tpu
|
from transformers.testing_utils import TestCasePlus, require_torch_xla
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
|
|||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_tpu
|
@require_torch_xla
|
||||||
class TorchXLAExamplesTests(TestCasePlus):
|
class TorchXLAExamplesTests(TestCasePlus):
|
||||||
def test_run_glue(self):
|
def test_run_glue(self):
|
||||||
import xla_spawn
|
import xla_spawn
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from transformers import Trainer, is_torch_tpu_available
|
from transformers import Trainer, is_torch_xla_available
|
||||||
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
from transformers import Seq2SeqTrainer, is_torch_xla_available
|
||||||
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -24,13 +24,13 @@ import quant_trainer
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from transformers import Trainer, is_torch_tpu_available
|
from transformers import Trainer, is_torch_xla_available
|
||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.trainer_utils import PredictionOutput
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
|
|||||||
@@ -1093,6 +1093,7 @@ _import_structure = {
|
|||||||
"is_torch_npu_available",
|
"is_torch_npu_available",
|
||||||
"is_torch_tpu_available",
|
"is_torch_tpu_available",
|
||||||
"is_torchvision_available",
|
"is_torchvision_available",
|
||||||
|
"is_torch_xla_available",
|
||||||
"is_torch_xpu_available",
|
"is_torch_xpu_available",
|
||||||
"is_vision_available",
|
"is_vision_available",
|
||||||
"logging",
|
"logging",
|
||||||
@@ -5897,6 +5898,7 @@ if TYPE_CHECKING:
|
|||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from typing import Tuple
|
|||||||
from ..utils import (
|
from ..utils import (
|
||||||
cached_property,
|
cached_property,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
@@ -31,7 +31,7 @@ from .benchmark_args_utils import BenchmarkArguments
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
|||||||
if not self.cuda:
|
if not self.cuda:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
n_gpu = 0
|
n_gpu = 0
|
||||||
elif is_torch_tpu_available():
|
elif is_torch_xla_available():
|
||||||
device = xm.xla_device()
|
device = xm.xla_device()
|
||||||
n_gpu = 0
|
n_gpu = 0
|
||||||
elif is_torch_xpu_available():
|
elif is_torch_xpu_available():
|
||||||
@@ -101,7 +101,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_tpu(self):
|
def is_tpu(self):
|
||||||
return is_torch_tpu_available() and self.tpu
|
return is_torch_xla_available() and self.tpu
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_idx(self) -> int:
|
def device_idx(self) -> int:
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ from .utils import (
|
|||||||
is_torch_fx_proxy,
|
is_torch_fx_proxy,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_training_run_on_sagemaker,
|
is_training_run_on_sagemaker,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ if TYPE_CHECKING and _has_neptune:
|
|||||||
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
|
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
|
||||||
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||||
from ..training_args import ParallelMode # noqa: E402
|
from ..training_args import ParallelMode # noqa: E402
|
||||||
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
|
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
# Integration functions:
|
# Integration functions:
|
||||||
@@ -752,7 +752,7 @@ class WandbCallback(TrainerCallback):
|
|||||||
|
|
||||||
# keep track of model topology and gradients, unsupported on TPU
|
# keep track of model topology and gradients, unsupported on TPU
|
||||||
_watch_model = os.getenv("WANDB_WATCH", "false")
|
_watch_model = os.getenv("WANDB_WATCH", "false")
|
||||||
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
|
if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
|
||||||
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
|
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
|
||||||
self._wandb.run._label(code="transformers_trainer")
|
self._wandb.run._label(code="transformers_trainer")
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,11 @@
|
|||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from ..utils import is_torch_tpu_available
|
from ..utils import is_torch_xla_available
|
||||||
|
|
||||||
|
|
||||||
def tpu_spmd_dataloader(dataloader: DataLoader):
|
def tpu_spmd_dataloader(dataloader: DataLoader):
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ from .utils import (
|
|||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
strtobool,
|
strtobool,
|
||||||
@@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
|
|||||||
# Adding fix for https://github.com/pytorch/xla/issues/4152
|
# Adding fix for https://github.com/pytorch/xla/issues/4152
|
||||||
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
|
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
|
||||||
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
|
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
|
||||||
# NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
|
# NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
|
||||||
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
|
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
|
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
|
||||||
if t.dtype == torch.float:
|
if t.dtype == torch.float:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if t.dtype == torch.double:
|
if t.dtype == torch.double:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from packaging import version
|
|||||||
from safetensors.torch import storage_ptr, storage_size
|
from safetensors.torch import storage_ptr, storage_size
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .utils import is_torch_tpu_available, logging
|
from .utils import is_torch_xla_available, logging
|
||||||
|
|
||||||
|
|
||||||
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
||||||
@@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
|
|||||||
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
|
||||||
non-overlapping lifetimes may have the same id.
|
non-overlapping lifetimes may have the same id.
|
||||||
"""
|
"""
|
||||||
if tensor.device.type == "xla" and is_torch_tpu_available():
|
if tensor.device.type == "xla" and is_torch_xla_available():
|
||||||
# NOTE: xla tensors dont have storage
|
# NOTE: xla tensors dont have storage
|
||||||
# use some other unique id to distinguish.
|
# use some other unique id to distinguish.
|
||||||
# this is a XLA tensor, it must be created using torch_xla's
|
# this is a XLA tensor, it must be created using torch_xla's
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ from .utils import (
|
|||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
|
|||||||
(test_case)
|
(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_tpu(test_case):
|
def require_torch_xla(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires a TPU (in PyTorch).
|
Decorator marking a test that requires TorchXLA (in PyTorch).
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
|
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_neuroncore(test_case):
|
def require_torch_neuroncore(test_case):
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ from .utils import (
|
|||||||
is_torch_compile_available,
|
is_torch_compile_available,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
logging,
|
logging,
|
||||||
strtobool,
|
strtobool,
|
||||||
)
|
)
|
||||||
@@ -170,7 +170,7 @@ if is_apex_available():
|
|||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
import torch_xla.distributed.spmd as xs
|
import torch_xla.distributed.spmd as xs
|
||||||
@@ -508,7 +508,7 @@ class Trainer:
|
|||||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
if is_torch_tpu_available() and self.optimizer is not None:
|
if is_torch_xla_available() and self.optimizer is not None:
|
||||||
for param in self.model.parameters():
|
for param in self.model.parameters():
|
||||||
model_device = param.device
|
model_device = param.device
|
||||||
break
|
break
|
||||||
@@ -856,7 +856,7 @@ class Trainer:
|
|||||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
|
||||||
# Deprecated code
|
# Deprecated code
|
||||||
if self.args.use_legacy_prediction_loop:
|
if self.args.use_legacy_prediction_loop:
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
return SequentialDistributedSampler(
|
return SequentialDistributedSampler(
|
||||||
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
||||||
)
|
)
|
||||||
@@ -1975,7 +1975,7 @@ class Trainer:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
args.logging_nan_inf_filter
|
args.logging_nan_inf_filter
|
||||||
and not is_torch_tpu_available()
|
and not is_torch_xla_available()
|
||||||
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
|
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
|
||||||
):
|
):
|
||||||
# if loss is nan or inf simply add the average of previous logged losses
|
# if loss is nan or inf simply add the average of previous logged losses
|
||||||
@@ -2027,7 +2027,7 @@ class Trainer:
|
|||||||
if hasattr(grad_norm, "item"):
|
if hasattr(grad_norm, "item"):
|
||||||
grad_norm = grad_norm.item()
|
grad_norm = grad_norm.item()
|
||||||
else:
|
else:
|
||||||
grad_norm = _grad_norm.item() if _grad_norm is not None else None
|
grad_norm = _grad_norm
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
@@ -2050,7 +2050,7 @@ class Trainer:
|
|||||||
# PyTorch/XLA relies on the data loader to insert the mark_step for
|
# PyTorch/XLA relies on the data loader to insert the mark_step for
|
||||||
# each step. Since we are breaking the loop early, we need to manually
|
# each step. Since we are breaking the loop early, we need to manually
|
||||||
# insert the mark_step here.
|
# insert the mark_step here.
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
break
|
break
|
||||||
if step < 0:
|
if step < 0:
|
||||||
@@ -2065,7 +2065,7 @@ class Trainer:
|
|||||||
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
|
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
|
||||||
|
|
||||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||||
xm.master_print(met.metrics_report())
|
xm.master_print(met.metrics_report())
|
||||||
else:
|
else:
|
||||||
@@ -2083,7 +2083,7 @@ class Trainer:
|
|||||||
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
||||||
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
||||||
# Wait for everyone to get here so we are sure the model has been saved by process 0.
|
# Wait for everyone to get here so we are sure the model has been saved by process 0.
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.rendezvous("load_best_model_at_end")
|
xm.rendezvous("load_best_model_at_end")
|
||||||
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
@@ -2402,7 +2402,7 @@ class Trainer:
|
|||||||
|
|
||||||
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
|
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
|
||||||
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
|
|
||||||
logs: Dict[str, float] = {}
|
logs: Dict[str, float] = {}
|
||||||
@@ -2415,7 +2415,7 @@ class Trainer:
|
|||||||
|
|
||||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||||
if grad_norm is not None:
|
if grad_norm is not None:
|
||||||
logs["grad_norm"] = grad_norm
|
logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
|
||||||
logs["learning_rate"] = self._get_learning_rate()
|
logs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
self._total_loss_scalar += tr_loss_scalar
|
self._total_loss_scalar += tr_loss_scalar
|
||||||
@@ -2478,7 +2478,7 @@ class Trainer:
|
|||||||
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
|
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
|
||||||
"\nThis won't yield the same results as if the training had not been interrupted."
|
"\nThis won't yield the same results as if the training had not been interrupted."
|
||||||
)
|
)
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.set_rng_state(checkpoint_rng_state["xla"])
|
xm.set_rng_state(checkpoint_rng_state["xla"])
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||||
@@ -2556,7 +2556,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
rng_states["xla"] = xm.get_rng_state()
|
rng_states["xla"] = xm.get_rng_state()
|
||||||
|
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
@@ -2575,7 +2575,7 @@ class Trainer:
|
|||||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
|
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
|
||||||
|
|
||||||
def _save_optimizer_and_scheduler(self, output_dir):
|
def _save_optimizer_and_scheduler(self, output_dir):
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.rendezvous("saving_optimizer_states")
|
xm.rendezvous("saving_optimizer_states")
|
||||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
@@ -2620,7 +2620,7 @@ class Trainer:
|
|||||||
if (
|
if (
|
||||||
self.args.should_save
|
self.args.should_save
|
||||||
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
|
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
|
||||||
and not is_torch_tpu_available()
|
and not is_torch_xla_available()
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
@@ -2657,7 +2657,7 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
|
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||||
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
@@ -2964,7 +2964,7 @@ class Trainer:
|
|||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = self.args.output_dir
|
output_dir = self.args.output_dir
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
self._save_tpu(output_dir)
|
self._save_tpu(output_dir)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
# Calling the state_dict needs to be done on the wrapped model and on all processes.
|
# Calling the state_dict needs to be done on the wrapped model and on all processes.
|
||||||
@@ -3405,7 +3405,7 @@ class Trainer:
|
|||||||
main_input_name = getattr(self.model, "main_input_name", "input_ids")
|
main_input_name = getattr(self.model, "main_input_name", "input_ids")
|
||||||
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
|
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
|
|
||||||
# Update containers on host
|
# Update containers on host
|
||||||
@@ -3529,7 +3529,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
if tensors is None:
|
if tensors is None:
|
||||||
return
|
return
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
if name is None:
|
if name is None:
|
||||||
name = "nested_gather"
|
name = "nested_gather"
|
||||||
tensors = nested_xla_mesh_reduce(tensors, name)
|
tensors = nested_xla_mesh_reduce(tensors, name)
|
||||||
@@ -4045,7 +4045,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
if tensors is None:
|
if tensors is None:
|
||||||
return
|
return
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
tensors = nested_xla_mesh_reduce(tensors, name)
|
tensors = nested_xla_mesh_reduce(tensors, name)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
tensors = smp_gather(tensors)
|
tensors = smp_gather(tensors)
|
||||||
|
|||||||
@@ -39,13 +39,13 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
|
|
||||||
from .integrations.deepspeed import is_deepspeed_zero3_enabled
|
from .integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
|
from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging
|
||||||
|
|
||||||
|
|
||||||
if is_training_run_on_sagemaker():
|
if is_training_run_on_sagemaker():
|
||||||
logging.add_handler(StreamHandler(sys.stdout))
|
logging.add_handler(StreamHandler(sys.stdout))
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
|
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
|
||||||
@@ -179,7 +179,7 @@ def nested_detach(tensors):
|
|||||||
|
|
||||||
|
|
||||||
def nested_xla_mesh_reduce(tensors, name):
|
def nested_xla_mesh_reduce(tensors, name):
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from .utils import (
|
|||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
@@ -340,7 +340,7 @@ def is_main_process(local_rank):
|
|||||||
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
||||||
`local_rank`.
|
`local_rank`.
|
||||||
"""
|
"""
|
||||||
if is_torch_tpu_available(check_device=True):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
return xm.get_ordinal() == 0
|
return xm.get_ordinal() == 0
|
||||||
@@ -351,7 +351,7 @@ def total_processes_number(local_rank):
|
|||||||
"""
|
"""
|
||||||
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
|
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
|
||||||
"""
|
"""
|
||||||
if is_torch_tpu_available(check_device=True):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
return xm.xrt_world_size()
|
return xm.xrt_world_size()
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from .utils import (
|
|||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
@@ -74,7 +74,7 @@ if is_accelerate_available():
|
|||||||
|
|
||||||
from .trainer_pt_utils import AcceleratorConfig
|
from .trainer_pt_utils import AcceleratorConfig
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_xla_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
if is_torch_neuroncore_available(check_device=False):
|
if is_torch_neuroncore_available(check_device=False):
|
||||||
@@ -130,7 +130,9 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]:
|
|||||||
"""
|
"""
|
||||||
Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
|
Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
|
||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
|
if device.type == "cpu":
|
||||||
|
return "CPU"
|
||||||
return xm.xla_real_devices([device])[0].split(":")[0]
|
return xm.xla_real_devices([device])[0].split(":")[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1475,7 +1477,7 @@ class TrainingArguments:
|
|||||||
self.half_precision_backend = self.fp16_backend
|
self.half_precision_backend = self.fp16_backend
|
||||||
|
|
||||||
if self.bf16 or self.bf16_full_eval:
|
if self.bf16 or self.bf16_full_eval:
|
||||||
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
|
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available():
|
||||||
# cpu
|
# cpu
|
||||||
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
|
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
|
||||||
elif not self.use_cpu:
|
elif not self.use_cpu:
|
||||||
@@ -1530,7 +1532,7 @@ class TrainingArguments:
|
|||||||
and (self.device.type != "cuda")
|
and (self.device.type != "cuda")
|
||||||
and (self.device.type != "npu")
|
and (self.device.type != "npu")
|
||||||
and (self.device.type != "xpu")
|
and (self.device.type != "xpu")
|
||||||
and (get_xla_device_type(self.device) != "GPU")
|
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
|
||||||
and (self.fp16 or self.fp16_full_eval)
|
and (self.fp16 or self.fp16_full_eval)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1544,7 +1546,7 @@ class TrainingArguments:
|
|||||||
and (self.device.type != "cuda")
|
and (self.device.type != "cuda")
|
||||||
and (self.device.type != "npu")
|
and (self.device.type != "npu")
|
||||||
and (self.device.type != "xpu")
|
and (self.device.type != "xpu")
|
||||||
and (get_xla_device_type(self.device) != "GPU")
|
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
|
||||||
and (get_xla_device_type(self.device) != "TPU")
|
and (get_xla_device_type(self.device) != "TPU")
|
||||||
and (self.device.type != "cpu")
|
and (self.device.type != "cpu")
|
||||||
and (self.bf16 or self.bf16_full_eval)
|
and (self.bf16 or self.bf16_full_eval)
|
||||||
@@ -1694,7 +1696,8 @@ class TrainingArguments:
|
|||||||
if self.fsdp_config["xla"]:
|
if self.fsdp_config["xla"]:
|
||||||
if len(self.fsdp) > 0:
|
if len(self.fsdp) > 0:
|
||||||
# store XLA fsdp configuration parameters into a dictionary
|
# store XLA fsdp configuration parameters into a dictionary
|
||||||
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {})
|
# Copy the config to avoid modifying the original config (which may be used for JSON serialization)
|
||||||
|
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
|
||||||
# apply appropriate string to torch.dtype conversions for parameters
|
# apply appropriate string to torch.dtype conversions for parameters
|
||||||
if "compute_dtype" in self.xla_fsdp_config:
|
if "compute_dtype" in self.xla_fsdp_config:
|
||||||
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
|
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
|
||||||
@@ -1948,7 +1951,7 @@ class TrainingArguments:
|
|||||||
"torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
|
"torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
|
||||||
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
|
"In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
|
||||||
)
|
)
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
device = self.distributed_state.device
|
device = self.distributed_state.device
|
||||||
self._n_gpu = 0
|
self._n_gpu = 0
|
||||||
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
|
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
|
||||||
@@ -2029,7 +2032,7 @@ class TrainingArguments:
|
|||||||
- `ParallelMode.TPU`: several TPU cores.
|
- `ParallelMode.TPU`: several TPU cores.
|
||||||
"""
|
"""
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
return ParallelMode.TPU
|
return ParallelMode.TPU
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
|
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
|
||||||
@@ -2180,7 +2183,7 @@ class TrainingArguments:
|
|||||||
# tell all replicas to wait
|
# tell all replicas to wait
|
||||||
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.rendezvous(desc)
|
xm.rendezvous(desc)
|
||||||
else:
|
else:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
@@ -2189,7 +2192,7 @@ class TrainingArguments:
|
|||||||
if is_main_process:
|
if is_main_process:
|
||||||
# the wait is over
|
# the wait is over
|
||||||
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
|
logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
|
||||||
if is_torch_tpu_available():
|
if is_torch_xla_available():
|
||||||
xm.rendezvous(desc)
|
xm.rendezvous(desc)
|
||||||
else:
|
else:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ from .import_utils import (
|
|||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_torchdistx_available,
|
is_torchdistx_available,
|
||||||
|
|||||||
@@ -62,6 +62,9 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
|||||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||||
|
|
||||||
|
# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
|
||||||
|
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
|
||||||
|
|
||||||
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
||||||
|
|
||||||
# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
|
# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
|
||||||
@@ -249,6 +252,13 @@ if _torch_available:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_torch_xla_available = False
|
||||||
|
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
|
||||||
|
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
|
||||||
|
if _torch_xla_available:
|
||||||
|
logger.info(f"Torch XLA version {_torch_xla_version} available.")
|
||||||
|
|
||||||
|
|
||||||
def is_kenlm_available():
|
def is_kenlm_available():
|
||||||
return _kenlm_available
|
return _kenlm_available
|
||||||
|
|
||||||
@@ -484,6 +494,12 @@ def is_g2p_en_available():
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def is_torch_tpu_available(check_device=True):
|
def is_torch_tpu_available(check_device=True):
|
||||||
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
||||||
|
warnings.warn(
|
||||||
|
"`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
|
||||||
|
"Please use the `is_torch_xla_available` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
if not _torch_available:
|
if not _torch_available:
|
||||||
return False
|
return False
|
||||||
if importlib.util.find_spec("torch_xla") is not None:
|
if importlib.util.find_spec("torch_xla") is not None:
|
||||||
@@ -500,10 +516,31 @@ def is_torch_tpu_available(check_device=True):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
|
||||||
|
"""
|
||||||
|
Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
|
||||||
|
the USE_TORCH_XLA to false.
|
||||||
|
"""
|
||||||
|
assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
|
||||||
|
|
||||||
|
if not _torch_xla_available:
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch_xla
|
||||||
|
|
||||||
|
if check_is_gpu:
|
||||||
|
return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
|
||||||
|
elif check_is_tpu:
|
||||||
|
return torch_xla.runtime.device_type() == "TPU"
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def is_torch_neuroncore_available(check_device=True):
|
def is_torch_neuroncore_available(check_device=True):
|
||||||
if importlib.util.find_spec("torch_neuronx") is not None:
|
if importlib.util.find_spec("torch_neuronx") is not None:
|
||||||
return is_torch_tpu_available(check_device)
|
return is_torch_xla_available()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user