Device agnostic trainer testing (#27131)

This commit is contained in:
Hz, Ji
2023-10-31 02:16:40 +08:00
committed by GitHub
parent 84724efd10
commit 5bbf671276
3 changed files with 87 additions and 46 deletions

View File

@@ -629,6 +629,20 @@ def require_torch_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
def require_torch_multi_accelerator(test_case):
"""
Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
test_case
)
def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
@@ -641,6 +655,16 @@ def require_torch_non_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
def require_torch_non_multi_accelerator(test_case):
"""
Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
def require_torch_up_to_2_gpus(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
@@ -653,6 +677,17 @@ def require_torch_up_to_2_gpus(test_case):
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
def require_torch_up_to_2_accelerators(test_case):
"""
Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")
(test_case)
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
@@ -774,7 +809,9 @@ def require_torch_gpu(test_case):
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
test_case
)
def require_torch_fp16(test_case):