these tests require non-multigpu env (#7059)
* these tests require non-multigpu env * cleanup * clarify
This commit is contained in:
@@ -122,6 +122,20 @@ def require_multigpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_non_multigpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
||||
"""
|
||||
if not _torch_available:
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
return unittest.skip("test requires 0 or 1 GPU")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
|
||||
Reference in New Issue
Block a user