Make torch xla available on GPU (#29334)
* add USE_TORCH_XLA env * rename torch_tpu to torch_xla * better is_torch_xla_available; fix some fsdp and performance issues * fix format * fix bug when pjrt_device is cpu * fix bug * fix the deprecation handling --------- Co-authored-by: anw90 <ang868@gmail.com> Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
@@ -46,7 +46,7 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.testing_utils import CaptureLogger
|
||||
@@ -602,9 +602,9 @@ def main():
|
||||
tokenizer=tokenizer,
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||
data_collator=default_data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||
if training_args.do_eval and not is_torch_tpu_available()
|
||||
if training_args.do_eval and not is_torch_xla_available()
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
@@ -620,9 +620,9 @@ def main():
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||
if training_args.do_eval and not is_torch_tpu_available()
|
||||
if training_args.do_eval and not is_torch_xla_available()
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import sys
|
||||
from time import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_tpu
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_xla
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
@require_torch_tpu
|
||||
@require_torch_xla
|
||||
class TorchXLAExamplesTests(TestCasePlus):
|
||||
def test_run_glue(self):
|
||||
import xla_spawn
|
||||
|
||||
@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
|
||||
import math
|
||||
import time
|
||||
|
||||
from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers import Trainer, is_torch_xla_available
|
||||
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
||||
|
||||
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
||||
from transformers import Seq2SeqTrainer, is_torch_xla_available
|
||||
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
||||
|
||||
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
Reference in New Issue
Block a user