[testing] skip decorators: docs, tests, bugs (#7334)
* skip decorators: docs, tests, bugs * another important note * style * bloody style * add @pytest.mark.parametrize * add note * no idea what it wants :(
This commit is contained in:
@@ -62,8 +62,9 @@ def slow(test_case):
|
||||
|
||||
"""
|
||||
if not _run_slow_tests:
|
||||
test_case = unittest.skip("test is slow")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test is slow")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def custom_tokenizers(test_case):
|
||||
@@ -75,8 +76,9 @@ def custom_tokenizers(test_case):
|
||||
to a truthy value to run them.
|
||||
"""
|
||||
if not _run_custom_tokenizers:
|
||||
test_case = unittest.skip("test of custom tokenizers")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test of custom tokenizers")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
@@ -87,8 +89,9 @@ def require_torch(test_case):
|
||||
|
||||
"""
|
||||
if not _torch_available:
|
||||
test_case = unittest.skip("test requires PyTorch")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_tf(test_case):
|
||||
@@ -99,8 +102,9 @@ def require_tf(test_case):
|
||||
|
||||
"""
|
||||
if not _tf_available:
|
||||
test_case = unittest.skip("test requires TensorFlow")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test requires TensorFlow")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_multigpu(test_case):
|
||||
@@ -119,7 +123,8 @@ def require_multigpu(test_case):
|
||||
|
||||
if torch.cuda.device_count() < 2:
|
||||
return unittest.skip("test requires multiple GPUs")(test_case)
|
||||
return test_case
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_non_multigpu(test_case):
|
||||
@@ -133,7 +138,8 @@ def require_non_multigpu(test_case):
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
return unittest.skip("test requires 0 or 1 GPU")(test_case)
|
||||
return test_case
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
@@ -142,8 +148,8 @@ def require_torch_tpu(test_case):
|
||||
"""
|
||||
if not _torch_tpu_available:
|
||||
return unittest.skip("test requires PyTorch TPU")
|
||||
|
||||
return test_case
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
if _torch_available:
|
||||
@@ -154,9 +160,9 @@ else:
|
||||
|
||||
|
||||
def require_torch_and_cuda(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch). """
|
||||
"""Decorator marking a test that requires CUDA and PyTorch. """
|
||||
if torch_device != "cuda":
|
||||
return unittest.skip("test requires CUDA")
|
||||
return unittest.skip("test requires CUDA")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
@@ -165,15 +171,17 @@ def require_datasets(test_case):
|
||||
"""Decorator marking a test that requires datasets."""
|
||||
|
||||
if not _datasets_available:
|
||||
test_case = unittest.skip("test requires Datasets")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test requires `datasets`")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_faiss(test_case):
|
||||
"""Decorator marking a test that requires faiss."""
|
||||
if not _faiss_available:
|
||||
test_case = unittest.skip("test requires Faiss")(test_case)
|
||||
return test_case
|
||||
return unittest.skip("test requires `faiss`")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def get_tests_dir():
|
||||
|
||||
Reference in New Issue
Block a user