diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 821fd8c545..2c9d9dccfe 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1728,7 +1728,7 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d return decorator -def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): """ To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. @@ -1739,9 +1739,12 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): The function implementing the actual testing logic. inputs (`dict`, *optional*, defaults to `None`): The inputs that will be passed to `target_func` through an (input) queue. - timeout (`int`, *optional*, defaults to 600): - The timeout (in seconds) that will be passed to the input and output queues. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) start_methohd = "spawn" ctx = multiprocessing.get_context(start_methohd) diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py index 33388eb6d3..508d96ae10 100644 --- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py @@ -15,7 +15,6 @@ import inspect import math import multiprocessing -import os import traceback import unittest @@ -637,7 +636,4 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_librosa def test_wav2vec2_with_lm_invalid_pool(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None) diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 42946fce49..2e3c2c26c8 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -19,7 +19,6 @@ import glob import inspect import math import multiprocessing -import os import traceback import unittest @@ -682,7 +681,4 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_librosa def test_wav2vec2_with_lm_invalid_pool(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None) diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 35df9fc223..f649257a83 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -1713,10 +1713,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): @require_pyctcdecode @require_torchaudio def test_wav2vec2_with_lm_invalid_pool(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None) def test_inference_diarization(self): model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 3bc04e56a6..4d3ce0bf38 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -15,7 +15,6 @@ """ Testing suite for the TensorFlow Whisper model. """ import inspect -import os import tempfile import traceback import unittest @@ -891,10 +890,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): @slow def test_large_logits_librispeech(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_large_logits_librispeech, inputs=None) @slow def test_tiny_en_generation(self): @@ -959,22 +955,15 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): @slow def test_large_generation(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout) + run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None) @slow def test_large_generation_multilingual(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None) @slow def test_large_batched_generation(self): - timeout = os.environ.get("PYTEST_TIMEOUT", 600) - run_test_in_subprocess( - test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout - ) + run_test_in_subprocess(test_case=self, target_func=_test_large_batched_generation, inputs=None) @slow def test_tiny_en_batched_generation(self):