Device agnostic trainer testing (#27131)
This commit is contained in:
@@ -629,6 +629,20 @@ def require_torch_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
|
||||
without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
|
||||
multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_non_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
||||
@@ -641,6 +655,16 @@ def require_torch_non_multi_gpu(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
|
||||
|
||||
|
||||
def require_torch_non_multi_accelerator(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torch_up_to_2_gpus(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
|
||||
@@ -653,6 +677,17 @@ def require_torch_up_to_2_gpus(test_case):
|
||||
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_up_to_2_accelerators(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")
|
||||
(test_case)
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
@@ -774,7 +809,9 @@ def require_torch_gpu(test_case):
|
||||
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
|
||||
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
|
||||
return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_fp16(test_case):
|
||||
|
||||
Reference in New Issue
Block a user