Make torch xla available on GPU (#29334)

* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
Yitong Huang
2024-03-11 22:07:16 +08:00
committed by GitHub
parent 9a3f4d4daf
commit 873d9bb3cc
25 changed files with 120 additions and 77 deletions

View File

@@ -115,7 +115,7 @@ from .utils import (
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
(test_case)
def require_torch_tpu(test_case):
def require_torch_xla(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
Decorator marking a test that requires TorchXLA (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
def require_torch_neuroncore(test_case):