From 57e6464ac9a31156f1c93e59107323e6ec01309e Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Fri, 29 Apr 2022 08:55:38 -0400 Subject: [PATCH] Update all require decorators to use skipUnless when possible (#16999) --- src/transformers/testing_utils.py | 232 +++++++----------------------- 1 file changed, 53 insertions(+), 179 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 36f56d2eeb..6e4546afb1 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -203,10 +203,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. """ - if not _run_slow_tests: - return unittest.skip("test is slow")(test_case) - else: - return test_case + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) def tooslow(test_case): @@ -227,10 +224,7 @@ def custom_tokenizers(test_case): Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS environment variable to a truthy value to run them. """ - if not _run_custom_tokenizers: - return unittest.skip("test of custom tokenizers")(test_case) - else: - return test_case + return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case) def require_git_lfs(test_case): @@ -240,34 +234,22 @@ def require_git_lfs(test_case): git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment variable to a truthy value to run them. """ - if not _run_git_lfs_tests: - return unittest.skip("test of git lfs workflow")(test_case) - else: - return test_case + return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case) def require_rjieba(test_case): """ Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. """ - if not is_rjieba_available(): - return unittest.skip("test requires rjieba")(test_case) - else: - return test_case + return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case) def require_tf2onnx(test_case): - if not is_tf2onnx_available(): - return unittest.skip("test requires tf2onnx")(test_case) - else: - return test_case + return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case) def require_onnx(test_case): - if not is_onnx_available(): - return unittest.skip("test requires ONNX")(test_case) - else: - return test_case + return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case) def require_timm(test_case): @@ -277,10 +259,7 @@ def require_timm(test_case): These tests are skipped when Timm isn't installed. """ - if not is_timm_available(): - return unittest.skip("test requires Timm")(test_case) - else: - return test_case + return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case) def require_torch(test_case): @@ -290,10 +269,7 @@ def require_torch(test_case): These tests are skipped when PyTorch isn't installed. """ - if not is_torch_available(): - return unittest.skip("test requires PyTorch")(test_case) - else: - return test_case + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) def require_torch_scatter(test_case): @@ -303,10 +279,7 @@ def require_torch_scatter(test_case): These tests are skipped when PyTorch scatter isn't installed. """ - if not is_scatter_available(): - return unittest.skip("test requires PyTorch scatter")(test_case) - else: - return test_case + return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case) def require_tensorflow_probability(test_case): @@ -316,89 +289,65 @@ def require_tensorflow_probability(test_case): These tests are skipped when TensorFlow probability isn't installed. """ - if not is_tensorflow_probability_available(): - return unittest.skip("test requires TensorFlow probability")(test_case) - else: - return test_case + return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")( + test_case + ) def require_torchaudio(test_case): """ Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed. """ - if not is_torchaudio_available(): - return unittest.skip("test requires torchaudio")(test_case) - else: - return test_case + return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case) def require_tf(test_case): """ Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed. """ - if not is_tf_available(): - return unittest.skip("test requires TensorFlow")(test_case) - else: - return test_case + return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case) def require_flax(test_case): """ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed """ - if not is_flax_available(): - test_case = unittest.skip("test requires JAX & Flax")(test_case) - return test_case + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) def require_sentencepiece(test_case): """ Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. """ - if not is_sentencepiece_available(): - return unittest.skip("test requires SentencePiece")(test_case) - else: - return test_case + return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case) def require_scipy(test_case): """ Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. """ - if not is_scipy_available(): - return unittest.skip("test requires Scipy")(test_case) - else: - return test_case + return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case) def require_tokenizers(test_case): """ Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. """ - if not is_tokenizers_available(): - return unittest.skip("test requires tokenizers")(test_case) - else: - return test_case + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) def require_pandas(test_case): """ Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. """ - if not is_pandas_available(): - return unittest.skip("test requires pandas")(test_case) - else: - return test_case + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) def require_pytesseract(test_case): """ Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed. """ - if not is_pytesseract_available(): - return unittest.skip("test requires PyTesseract")(test_case) - else: - return test_case + return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case) def require_scatter(test_case): @@ -406,10 +355,7 @@ def require_scatter(test_case): Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't installed. """ - if not is_scatter_available(): - return unittest.skip("test requires PyTorch Scatter")(test_case) - else: - return test_case + return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case) def require_pytorch_quantization(test_case): @@ -417,10 +363,9 @@ def require_pytorch_quantization(test_case): Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch Quantization Toolkit isn't installed. """ - if not is_pytorch_quantization_available(): - return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case) - else: - return test_case + return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")( + test_case + ) def require_vision(test_case): @@ -428,30 +373,21 @@ def require_vision(test_case): Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't installed. """ - if not is_vision_available(): - return unittest.skip("test requires vision")(test_case) - else: - return test_case + return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case) def require_ftfy(test_case): """ Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. """ - if not is_ftfy_available(): - return unittest.skip("test requires ftfy")(test_case) - else: - return test_case + return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case) def require_spacy(test_case): """ Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed. """ - if not is_spacy_available(): - return unittest.skip("test requires spacy")(test_case) - else: - return test_case + return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case) def require_torch_multi_gpu(test_case): @@ -466,10 +402,7 @@ def require_torch_multi_gpu(test_case): import torch - if torch.cuda.device_count() < 2: - return unittest.skip("test requires multiple GPUs")(test_case) - else: - return test_case + return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) def require_torch_non_multi_gpu(test_case): @@ -481,10 +414,7 @@ def require_torch_non_multi_gpu(test_case): import torch - if torch.cuda.device_count() > 1: - return unittest.skip("test requires 0 or 1 GPU")(test_case) - else: - return test_case + return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case) def require_torch_up_to_2_gpus(test_case): @@ -496,20 +426,14 @@ def require_torch_up_to_2_gpus(test_case): import torch - if torch.cuda.device_count() > 2: - return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case) - else: - return test_case + return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case) def require_torch_tpu(test_case): """ Decorator marking a test that requires a TPU (in PyTorch). """ - if not is_torch_tpu_available(): - return unittest.skip("test requires PyTorch TPU") - else: - return test_case + return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case) if is_torch_available(): @@ -533,42 +457,31 @@ else: def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" - if torch_device != "cuda": - return unittest.skip("test requires CUDA")(test_case) - else: - return test_case + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) def require_torch_bf16(test_case): """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10.""" - if not is_torch_bf16_available(): - return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case) - else: - return test_case + return unittest.skipUnless( + is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10" + )(test_case) def require_torch_tf32(test_case): """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.""" - if not is_torch_tf32_available(): - return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case) - else: - return test_case + return unittest.skipUnless( + is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7" + )(test_case) def require_detectron2(test_case): """Decorator marking a test that requires detectron2.""" - if not is_detectron2_available(): - return unittest.skip("test requires `detectron2`")(test_case) - else: - return test_case + return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case) def require_faiss(test_case): """Decorator marking a test that requires faiss.""" - if not is_faiss_available(): - return unittest.skip("test requires `faiss`")(test_case) - else: - return test_case + return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case) def require_optuna(test_case): @@ -578,10 +491,7 @@ def require_optuna(test_case): These tests are skipped when optuna isn't installed. """ - if not is_optuna_available(): - return unittest.skip("test requires optuna")(test_case) - else: - return test_case + return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case) def require_ray(test_case): @@ -591,10 +501,7 @@ def require_ray(test_case): These tests are skipped when Ray/tune isn't installed. """ - if not is_ray_available(): - return unittest.skip("test requires Ray/tune")(test_case) - else: - return test_case + return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case) def require_sigopt(test_case): @@ -604,10 +511,7 @@ def require_sigopt(test_case): These tests are skipped when SigOpt isn't installed. """ - if not is_sigopt_available(): - return unittest.skip("test requires SigOpt")(test_case) - else: - return test_case + return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case) def require_wandb(test_case): @@ -617,10 +521,7 @@ def require_wandb(test_case): These tests are skipped when wandb isn't installed. """ - if not is_wandb_available(): - return unittest.skip("test requires wandb")(test_case) - else: - return test_case + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) def require_soundfile(test_case): @@ -630,80 +531,56 @@ def require_soundfile(test_case): These tests are skipped when soundfile isn't installed. """ - if not is_soundfile_availble(): - return unittest.skip("test requires soundfile")(test_case) - else: - return test_case + return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case) def require_deepspeed(test_case): """ Decorator marking a test that requires deepspeed """ - if not is_deepspeed_available(): - return unittest.skip("test requires deepspeed")(test_case) - else: - return test_case + return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case) def require_fairscale(test_case): """ Decorator marking a test that requires fairscale """ - if not is_fairscale_available(): - return unittest.skip("test requires fairscale")(test_case) - else: - return test_case + return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case) def require_apex(test_case): """ Decorator marking a test that requires apex """ - if not is_apex_available(): - return unittest.skip("test requires apex")(test_case) - else: - return test_case + return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case) def require_bitsandbytes(test_case): """ Decorator for bits and bytes (bnb) dependency """ - if not is_bitsandbytes_available(): - return unittest.skip("test requires bnb")(test_case) - else: - return test_case + return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case) def require_phonemizer(test_case): """ Decorator marking a test that requires phonemizer """ - if not is_phonemizer_available(): - return unittest.skip("test requires phonemizer")(test_case) - else: - return test_case + return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case) def require_pyctcdecode(test_case): """ Decorator marking a test that requires pyctcdecode """ - if not is_pyctcdecode_available(): - return unittest.skip("test requires pyctcdecode")(test_case) - else: - return test_case + return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case) def require_librosa(test_case): """ Decorator marking a test that requires librosa """ - if not is_librosa_available(): - return unittest.skip("test requires librosa")(test_case) - else: - return test_case + return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) def cmd_exists(cmd): @@ -714,10 +591,7 @@ def require_usr_bin_time(test_case): """ Decorator marking a test that requires `/usr/bin/time` """ - if not cmd_exists("/usr/bin/time"): - return unittest.skip("test requires /usr/bin/time")(test_case) - else: - return test_case + return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case) def get_gpu_count():