Make torch xla available on GPU (#29334)

* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
Yitong Huang
2024-03-11 22:07:16 +08:00
committed by GitHub
parent 9a3f4d4daf
commit 873d9bb3cc
25 changed files with 120 additions and 77 deletions

View File

@@ -32,7 +32,7 @@ from transformers.optimization import (
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
from transformers.utils import is_torch_xla_available
logger = logging.get_logger(__name__)
@@ -135,7 +135,7 @@ class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
elif is_torch_xla_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler: