Update all require decorators to use skipUnless when possible (#16999)
This commit is contained in:
@@ -203,10 +203,7 @@ def slow(test_case):
|
|||||||
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
|
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not _run_slow_tests:
|
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||||
return unittest.skip("test is slow")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def tooslow(test_case):
|
def tooslow(test_case):
|
||||||
@@ -227,10 +224,7 @@ def custom_tokenizers(test_case):
|
|||||||
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
|
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
|
||||||
environment variable to a truthy value to run them.
|
environment variable to a truthy value to run them.
|
||||||
"""
|
"""
|
||||||
if not _run_custom_tokenizers:
|
return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
|
||||||
return unittest.skip("test of custom tokenizers")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_git_lfs(test_case):
|
def require_git_lfs(test_case):
|
||||||
@@ -240,34 +234,22 @@ def require_git_lfs(test_case):
|
|||||||
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
|
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
|
||||||
variable to a truthy value to run them.
|
variable to a truthy value to run them.
|
||||||
"""
|
"""
|
||||||
if not _run_git_lfs_tests:
|
return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
|
||||||
return unittest.skip("test of git lfs workflow")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_rjieba(test_case):
|
def require_rjieba(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
|
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_rjieba_available():
|
return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
|
||||||
return unittest.skip("test requires rjieba")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_tf2onnx(test_case):
|
def require_tf2onnx(test_case):
|
||||||
if not is_tf2onnx_available():
|
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
|
||||||
return unittest.skip("test requires tf2onnx")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_onnx(test_case):
|
def require_onnx(test_case):
|
||||||
if not is_onnx_available():
|
return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
|
||||||
return unittest.skip("test requires ONNX")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_timm(test_case):
|
def require_timm(test_case):
|
||||||
@@ -277,10 +259,7 @@ def require_timm(test_case):
|
|||||||
These tests are skipped when Timm isn't installed.
|
These tests are skipped when Timm isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_timm_available():
|
return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
|
||||||
return unittest.skip("test requires Timm")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch(test_case):
|
def require_torch(test_case):
|
||||||
@@ -290,10 +269,7 @@ def require_torch(test_case):
|
|||||||
These tests are skipped when PyTorch isn't installed.
|
These tests are skipped when PyTorch isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_torch_available():
|
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||||
return unittest.skip("test requires PyTorch")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_scatter(test_case):
|
def require_torch_scatter(test_case):
|
||||||
@@ -303,10 +279,7 @@ def require_torch_scatter(test_case):
|
|||||||
These tests are skipped when PyTorch scatter isn't installed.
|
These tests are skipped when PyTorch scatter isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_scatter_available():
|
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)
|
||||||
return unittest.skip("test requires PyTorch scatter")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_tensorflow_probability(test_case):
|
def require_tensorflow_probability(test_case):
|
||||||
@@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case):
|
|||||||
These tests are skipped when TensorFlow probability isn't installed.
|
These tests are skipped when TensorFlow probability isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_tensorflow_probability_available():
|
return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
|
||||||
return unittest.skip("test requires TensorFlow probability")(test_case)
|
test_case
|
||||||
else:
|
)
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torchaudio(test_case):
|
def require_torchaudio(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
|
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_torchaudio_available():
|
return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
|
||||||
return unittest.skip("test requires torchaudio")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_tf(test_case):
|
def require_tf(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
|
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_tf_available():
|
return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
|
||||||
return unittest.skip("test requires TensorFlow")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_flax(test_case):
|
def require_flax(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||||
"""
|
"""
|
||||||
if not is_flax_available():
|
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||||
test_case = unittest.skip("test requires JAX & Flax")(test_case)
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_sentencepiece(test_case):
|
def require_sentencepiece(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
|
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_sentencepiece_available():
|
return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
|
||||||
return unittest.skip("test requires SentencePiece")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_scipy(test_case):
|
def require_scipy(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
|
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_scipy_available():
|
return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
|
||||||
return unittest.skip("test requires Scipy")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_tokenizers(test_case):
|
def require_tokenizers(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
|
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_tokenizers_available():
|
return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
|
||||||
return unittest.skip("test requires tokenizers")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_pandas(test_case):
|
def require_pandas(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
|
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_pandas_available():
|
return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
|
||||||
return unittest.skip("test requires pandas")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_pytesseract(test_case):
|
def require_pytesseract(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
|
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_pytesseract_available():
|
return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
|
||||||
return unittest.skip("test requires PyTesseract")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_scatter(test_case):
|
def require_scatter(test_case):
|
||||||
@@ -406,10 +355,7 @@ def require_scatter(test_case):
|
|||||||
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
|
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
|
||||||
installed.
|
installed.
|
||||||
"""
|
"""
|
||||||
if not is_scatter_available():
|
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)
|
||||||
return unittest.skip("test requires PyTorch Scatter")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_pytorch_quantization(test_case):
|
def require_pytorch_quantization(test_case):
|
||||||
@@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case):
|
|||||||
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
|
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
|
||||||
Quantization Toolkit isn't installed.
|
Quantization Toolkit isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_pytorch_quantization_available():
|
return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
|
||||||
return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case)
|
test_case
|
||||||
else:
|
)
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_vision(test_case):
|
def require_vision(test_case):
|
||||||
@@ -428,30 +373,21 @@ def require_vision(test_case):
|
|||||||
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
|
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
|
||||||
installed.
|
installed.
|
||||||
"""
|
"""
|
||||||
if not is_vision_available():
|
return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
|
||||||
return unittest.skip("test requires vision")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_ftfy(test_case):
|
def require_ftfy(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
|
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_ftfy_available():
|
return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
|
||||||
return unittest.skip("test requires ftfy")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_spacy(test_case):
|
def require_spacy(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
|
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
|
||||||
"""
|
"""
|
||||||
if not is_spacy_available():
|
return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
|
||||||
return unittest.skip("test requires spacy")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_multi_gpu(test_case):
|
def require_torch_multi_gpu(test_case):
|
||||||
@@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case):
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.device_count() < 2:
|
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||||
return unittest.skip("test requires multiple GPUs")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_non_multi_gpu(test_case):
|
def require_torch_non_multi_gpu(test_case):
|
||||||
@@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case):
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
|
||||||
return unittest.skip("test requires 0 or 1 GPU")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_up_to_2_gpus(test_case):
|
def require_torch_up_to_2_gpus(test_case):
|
||||||
@@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case):
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.device_count() > 2:
|
return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
|
||||||
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_tpu(test_case):
|
def require_torch_tpu(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires a TPU (in PyTorch).
|
Decorator marking a test that requires a TPU (in PyTorch).
|
||||||
"""
|
"""
|
||||||
if not is_torch_tpu_available():
|
return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
|
||||||
return unittest.skip("test requires PyTorch TPU")
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -533,42 +457,31 @@ else:
|
|||||||
|
|
||||||
def require_torch_gpu(test_case):
|
def require_torch_gpu(test_case):
|
||||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||||
if torch_device != "cuda":
|
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||||
return unittest.skip("test requires CUDA")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_bf16(test_case):
|
def require_torch_bf16(test_case):
|
||||||
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
|
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
|
||||||
if not is_torch_bf16_available():
|
return unittest.skipUnless(
|
||||||
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
|
is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
|
||||||
else:
|
)(test_case)
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_torch_tf32(test_case):
|
def require_torch_tf32(test_case):
|
||||||
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
|
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
|
||||||
if not is_torch_tf32_available():
|
return unittest.skipUnless(
|
||||||
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
|
is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
|
||||||
else:
|
)(test_case)
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_detectron2(test_case):
|
def require_detectron2(test_case):
|
||||||
"""Decorator marking a test that requires detectron2."""
|
"""Decorator marking a test that requires detectron2."""
|
||||||
if not is_detectron2_available():
|
return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
|
||||||
return unittest.skip("test requires `detectron2`")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_faiss(test_case):
|
def require_faiss(test_case):
|
||||||
"""Decorator marking a test that requires faiss."""
|
"""Decorator marking a test that requires faiss."""
|
||||||
if not is_faiss_available():
|
return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
|
||||||
return unittest.skip("test requires `faiss`")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_optuna(test_case):
|
def require_optuna(test_case):
|
||||||
@@ -578,10 +491,7 @@ def require_optuna(test_case):
|
|||||||
These tests are skipped when optuna isn't installed.
|
These tests are skipped when optuna isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_optuna_available():
|
return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
|
||||||
return unittest.skip("test requires optuna")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_ray(test_case):
|
def require_ray(test_case):
|
||||||
@@ -591,10 +501,7 @@ def require_ray(test_case):
|
|||||||
These tests are skipped when Ray/tune isn't installed.
|
These tests are skipped when Ray/tune isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_ray_available():
|
return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
|
||||||
return unittest.skip("test requires Ray/tune")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_sigopt(test_case):
|
def require_sigopt(test_case):
|
||||||
@@ -604,10 +511,7 @@ def require_sigopt(test_case):
|
|||||||
These tests are skipped when SigOpt isn't installed.
|
These tests are skipped when SigOpt isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_sigopt_available():
|
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
|
||||||
return unittest.skip("test requires SigOpt")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_wandb(test_case):
|
def require_wandb(test_case):
|
||||||
@@ -617,10 +521,7 @@ def require_wandb(test_case):
|
|||||||
These tests are skipped when wandb isn't installed.
|
These tests are skipped when wandb isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_wandb_available():
|
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
|
||||||
return unittest.skip("test requires wandb")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_soundfile(test_case):
|
def require_soundfile(test_case):
|
||||||
@@ -630,80 +531,56 @@ def require_soundfile(test_case):
|
|||||||
These tests are skipped when soundfile isn't installed.
|
These tests are skipped when soundfile isn't installed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_soundfile_availble():
|
return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)
|
||||||
return unittest.skip("test requires soundfile")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_deepspeed(test_case):
|
def require_deepspeed(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires deepspeed
|
Decorator marking a test that requires deepspeed
|
||||||
"""
|
"""
|
||||||
if not is_deepspeed_available():
|
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
|
||||||
return unittest.skip("test requires deepspeed")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_fairscale(test_case):
|
def require_fairscale(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires fairscale
|
Decorator marking a test that requires fairscale
|
||||||
"""
|
"""
|
||||||
if not is_fairscale_available():
|
return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
|
||||||
return unittest.skip("test requires fairscale")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_apex(test_case):
|
def require_apex(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires apex
|
Decorator marking a test that requires apex
|
||||||
"""
|
"""
|
||||||
if not is_apex_available():
|
return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
|
||||||
return unittest.skip("test requires apex")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_bitsandbytes(test_case):
|
def require_bitsandbytes(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator for bits and bytes (bnb) dependency
|
Decorator for bits and bytes (bnb) dependency
|
||||||
"""
|
"""
|
||||||
if not is_bitsandbytes_available():
|
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)
|
||||||
return unittest.skip("test requires bnb")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_phonemizer(test_case):
|
def require_phonemizer(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires phonemizer
|
Decorator marking a test that requires phonemizer
|
||||||
"""
|
"""
|
||||||
if not is_phonemizer_available():
|
return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
|
||||||
return unittest.skip("test requires phonemizer")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_pyctcdecode(test_case):
|
def require_pyctcdecode(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires pyctcdecode
|
Decorator marking a test that requires pyctcdecode
|
||||||
"""
|
"""
|
||||||
if not is_pyctcdecode_available():
|
return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
|
||||||
return unittest.skip("test requires pyctcdecode")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_librosa(test_case):
|
def require_librosa(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires librosa
|
Decorator marking a test that requires librosa
|
||||||
"""
|
"""
|
||||||
if not is_librosa_available():
|
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
|
||||||
return unittest.skip("test requires librosa")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_exists(cmd):
|
def cmd_exists(cmd):
|
||||||
@@ -714,10 +591,7 @@ def require_usr_bin_time(test_case):
|
|||||||
"""
|
"""
|
||||||
Decorator marking a test that requires `/usr/bin/time`
|
Decorator marking a test that requires `/usr/bin/time`
|
||||||
"""
|
"""
|
||||||
if not cmd_exists("/usr/bin/time"):
|
return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
|
||||||
return unittest.skip("test requires /usr/bin/time")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def get_gpu_count():
|
def get_gpu_count():
|
||||||
|
|||||||
Reference in New Issue
Block a user