Run some TF Whisper tests in subprocesses to avoid GPU OOM (#19772)
* Run some TF Whisper tests in subprocesses to avoid GPU OOM

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
@@ -1672,3 +1673,43 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
    """
    Run a test in a subprocess. In particular, this can avoid (GPU) memory issues, since the child's memory is
    released when the child process exits.

    Args:
        test_case (`unittest.TestCase`):
            The test that will run `target_func`. Only its `fail` method is used here, to report errors.
        target_func (`Callable`):
            The function implementing the actual testing logic. It is started in the child process and called as
            `target_func(input_queue, output_queue, timeout)`. It is expected to put a single dict with an
            `"error"` key (`None` on success) on `output_queue`.
        inputs (`dict`, *optional*, defaults to `None`):
            The inputs that will be passed to `target_func` through an (input) queue.
        timeout (`int`, *optional*, defaults to 600):
            The timeout (in seconds) used when putting to the input queue, when getting from the output queue and
            when joining the child process. It is also forwarded to `target_func`.
    """
    # "spawn" starts a fresh interpreter for the child instead of forking the current one.
    start_method = "spawn"
    ctx = multiprocessing.get_context(start_method)

    input_queue = ctx.Queue(1)
    output_queue = ctx.JoinableQueue(1)

    # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
    input_queue.put(inputs, timeout=timeout)

    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
    process.start()
    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
    # the test from exiting properly.
    try:
        results = output_queue.get(timeout=timeout)
        output_queue.task_done()
    except Exception as e:
        process.terminate()
        # Reap the terminated child so that it does not linger as a zombie process.
        process.join(timeout=timeout)
        # `queue.Empty` (raised on timeout) has an empty string representation, so build an explicit message.
        test_case.fail(f"Failed to get the results from the subprocess within {timeout} seconds: {e!r}")
        # `fail` normally raises; return anyway so that `results` is never read while unbound.
        return

    process.join(timeout=timeout)
    if process.is_alive():
        # The child delivered its results but did not exit in time: don't leave it hanging around.
        process.terminate()
        process.join(timeout=timeout)

    if results["error"] is not None:
        test_case.fail(f'{results["error"]}')
|
||||
|
||||
Reference in New Issue
Block a user