Patch with accelerate xpu (#25714)
* patch with accelerate xpu * patch with accelerate xpu * formatting * fix tests * revert ruff unrelated fixes * revert ruff unrelated fixes * revert ruff unrelated fixes * fix test * review fixes * review fixes * black fixed * review commits * review commits * style fix * use pytorch_utils * revert markuplm test
This commit is contained in:
committed by
GitHub
parent
aa5c94d38d
commit
70a98024b1
@@ -100,6 +100,7 @@ from .utils import (
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
is_torchvision_available,
|
||||
@@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
|
||||
return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_xpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires XPU and IPEX.
|
||||
|
||||
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
|
||||
version.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_xpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-XPU setup with IPEX and atleast one XPU device. These tests are
|
||||
skipped on a machine without IPEX or multiple XPUs.
|
||||
|
||||
To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
|
||||
"""
|
||||
if not is_torch_xpu_available():
|
||||
return unittest.skip("test requires IPEX and atleast one XPU device")(test_case)
|
||||
|
||||
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||
import torch
|
||||
@@ -641,6 +665,8 @@ if is_torch_available():
|
||||
torch_device = "cuda"
|
||||
elif _run_third_party_device_tests and is_torch_npu_available():
|
||||
torch_device = "npu"
|
||||
elif _run_third_party_device_tests and is_torch_xpu_available():
|
||||
torch_device = "xpu"
|
||||
else:
|
||||
torch_device = "cpu"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user