Patch with accelerate xpu (#25714)

* patch with accelerate xpu

* patch with accelerate xpu

* formatting

* fix tests

* revert ruff unrelated fixes

* revert ruff unrelated fixes

* revert ruff unrelated fixes

* fix test

* review fixes

* review fixes

* black fixed

* review commits

* review commits

* style fix

* use pytorch_utils

* revert markuplm test
This commit is contained in:
Abhilash Majumder
2023-09-05 20:11:42 +05:30
committed by GitHub
parent aa5c94d38d
commit 70a98024b1
7 changed files with 127 additions and 14 deletions

View File

@@ -100,6 +100,7 @@ from .utils import (
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_torchvision_available,
@@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
def require_torch_xpu(test_case):
"""
Decorator marking a test that requires XPU and IPEX.
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
version.
"""
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires a multi-XPU setup with IPEX and atleast one XPU device. These tests are
skipped on a machine without IPEX or multiple XPUs.
To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
"""
if not is_torch_xpu_available():
return unittest.skip("test requires IPEX and atleast one XPU device")(test_case)
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
@@ -641,6 +665,8 @@ if is_torch_available():
torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu"
elif _run_third_party_device_tests and is_torch_xpu_available():
torch_device = "xpu"
else:
torch_device = "cpu"