Fix env. variable type issue in testing (#21609)

* fix env issue

* fix env issue

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-13 20:53:26 +01:00
committed by GitHub
parent 5987e0ab69
commit cbecf121cd
5 changed files with 13 additions and 32 deletions

View File

@@ -1728,7 +1728,7 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
return decorator return decorator
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
""" """
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
@@ -1739,9 +1739,12 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
The function implementing the actual testing logic. The function implementing the actual testing logic.
inputs (`dict`, *optional*, defaults to `None`): inputs (`dict`, *optional*, defaults to `None`):
The inputs that will be passed to `target_func` through an (input) queue. The inputs that will be passed to `target_func` through an (input) queue.
timeout (`int`, *optional*, defaults to 600): timeout (`int`, *optional*, defaults to `None`):
The timeout (in seconds) that will be passed to the input and output queues. The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
""" """
if timeout is None:
timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
start_methohd = "spawn" start_methohd = "spawn"
ctx = multiprocessing.get_context(start_methohd) ctx = multiprocessing.get_context(start_methohd)

View File

@@ -15,7 +15,6 @@
import inspect import inspect
import math import math
import multiprocessing import multiprocessing
import os
import traceback import traceback
import unittest import unittest
@@ -637,7 +636,4 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_librosa @require_librosa
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)

View File

@@ -19,7 +19,6 @@ import glob
import inspect import inspect
import math import math
import multiprocessing import multiprocessing
import os
import traceback import traceback
import unittest import unittest
@@ -682,7 +681,4 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_librosa @require_librosa
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)

View File

@@ -1713,10 +1713,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_torchaudio @require_torchaudio
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)
def test_inference_diarization(self): def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device) model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)

View File

@@ -15,7 +15,6 @@
""" Testing suite for the TensorFlow Whisper model. """ """ Testing suite for the TensorFlow Whisper model. """
import inspect import inspect
import os
import tempfile import tempfile
import traceback import traceback
import unittest import unittest
@@ -891,10 +890,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
@slow @slow
def test_large_logits_librispeech(self): def test_large_logits_librispeech(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_logits_librispeech, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout
)
@slow @slow
def test_tiny_en_generation(self): def test_tiny_en_generation(self):
@@ -959,22 +955,15 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
@slow @slow
def test_large_generation(self): def test_large_generation(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout)
@slow @slow
def test_large_generation_multilingual(self): def test_large_generation_multilingual(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout
)
@slow @slow
def test_large_batched_generation(self): def test_large_batched_generation(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_batched_generation, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout
)
@slow @slow
def test_tiny_en_batched_generation(self): def test_tiny_en_batched_generation(self):