[tests] multiple improvements (#12294)
* [tests] multiple improvements * cleanup * style * todo to investigate * fix
This commit is contained in:
@@ -383,6 +383,21 @@ def require_torch_non_multi_gpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_up_to_2_gpus(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if torch.cuda.device_count() > 2:
|
||||
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
|
||||
Reference in New Issue
Block a user