Make torch xla available on GPU (#29334)

* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
Yitong Huang
2024-03-11 22:07:16 +08:00
committed by GitHub
parent 9a3f4d4daf
commit 873d9bb3cc
25 changed files with 120 additions and 77 deletions

View File

@@ -46,7 +46,7 @@ from transformers import (
Trainer,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
@@ -602,9 +602,9 @@ def main():
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

View File

@@ -45,7 +45,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
@@ -620,9 +620,9 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

View File

@@ -21,7 +21,7 @@ import sys
from time import time
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, require_torch_tpu
from transformers.testing_utils import TestCasePlus, require_torch_xla
logging.basicConfig(level=logging.DEBUG)
@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@require_torch_tpu
@require_torch_xla
class TorchXLAExamplesTests(TestCasePlus):
def test_run_glue(self):
import xla_spawn

View File

@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
import math
import time
from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

View File

@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met