Make torch xla available on GPU (#29334)
* add USE_TORCH_XLA env * rename torch_tpu to torch_xla * better is_torch_xla_available; fix some fsdp and performance issues * fix format * fix bug when pjrt_device is cpu * fix bug * fix the deprecation handling --------- Co-authored-by: anw90 <ang868@gmail.com> Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
@@ -115,7 +115,7 @@ from .utils import (
|
||||
is_torch_sdpa_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
|
||||
(test_case)
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
def require_torch_xla(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
Decorator marking a test that requires TorchXLA (in PyTorch).
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
|
||||
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
|
||||
|
||||
|
||||
def require_torch_neuroncore(test_case):
|
||||
|
||||
Reference in New Issue
Block a user