Make torch xla available on GPU (#29334)

* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
Yitong Huang
2024-03-11 22:07:16 +08:00
committed by GitHub
parent 9a3f4d4daf
commit 873d9bb3cc
25 changed files with 120 additions and 77 deletions

View File

@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met