[s2s trainer] tests to use distributed on multi-gpu machine (#7965)
This commit is contained in:
@@ -6,11 +6,15 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
||||
|
||||
from .utils import load_json
|
||||
from .utils import execute_async_std, load_json
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -106,73 +110,6 @@ def make_test_data_dir(tmp_dir):
|
||||
return tmp_dir
|
||||
|
||||
|
||||
# XXX: a candidate for testing_utils (python>=3.6)
|
||||
# https://stackoverflow.com/a/59041913/9201239
|
||||
import asyncio # noqa
|
||||
|
||||
|
||||
class RunOutput:
    """Plain container for the outcome of a subprocess run.

    The three values are stored exactly as they are passed in:

    - ``returncode``: the return code of the process
    - ``stdout``: whatever was captured from the process' standard output
    - ``stderr``: whatever was captured from the process' standard error
    """

    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr

    def __repr__(self):
        # A readable repr makes a failed assertion on a result self-explanatory
        # instead of showing only an object address.
        return (
            f"{type(self).__name__}(returncode={self.returncode!r}, "
            f"stdout={self.stdout!r}, stderr={self.stderr!r})"
        )
|
||||
|
||||
|
||||
async def _read_stream(stream, callback):
    """Pass each line read from ``stream`` to ``callback`` until the stream is exhausted.

    ``stream.readline()`` is awaited repeatedly; the first falsy result ends the
    loop and is not forwarded to ``callback``.
    """
    while True:
        chunk = await stream.readline()
        if not chunk:
            return
        callback(chunk)
|
||||
|
||||
|
||||
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    """Run ``cmd`` as a subprocess and capture its stdout/stderr line by line.

    Args:
        cmd: argument list; ``cmd[0]`` is the program, the remaining items are its arguments.
        env: environment handed to the subprocess.
        stdin: passed through unchanged as the subprocess' stdin.
        timeout: passed to ``asyncio.wait`` as the limit (in seconds) on waiting for the two
            pipe readers. It does not terminate the process.
        quiet: when true, captured lines are collected but not echoed.
        echo: when true, ``cmd`` is printed before the subprocess is started.

    Returns:
        A ``RunOutput`` holding the process return code and the captured stdout and stderr
        lines (decoded as UTF-8, trailing whitespace stripped).
    """
    if echo:
        print(cmd)

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )
    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        # errors="replace": a single undecodable byte must not kill the reader, otherwise
        # the pipe would stop being drained while the child keeps writing to it.
        line = line.decode("utf-8", errors="replace").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # asyncio.wait() must be given tasks/futures: passing bare coroutine objects is deprecated
    # since Python 3.8 and raises TypeError from Python 3.11 on, so schedule them explicitly.
    readers = [
        asyncio.ensure_future(_read_stream(p.stdout, lambda line: tee(line, out, sys.stdout))),
        asyncio.ensure_future(_read_stream(p.stderr, lambda line: tee(line, err, sys.stderr, label="stderr:"))),
    ]
    await asyncio.wait(readers, timeout=timeout)

    # XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
    # `wait` as it's easier to see in real time, but for normal runs use `communicate`
    return RunOutput(await p.wait(), out, err)
|
||||
|
||||
|
||||
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
    """Synchronously run ``cmd`` through ``_stream_subprocess`` and return its result.

    All arguments are forwarded unchanged to ``_stream_subprocess``; the call blocks until
    that coroutine completes.

    Returns:
        The ``RunOutput`` produced by ``_stream_subprocess``.
    """
    # `asyncio.get_event_loop()` can raise when no current loop is set, and it may hand back
    # a loop that has already been closed; fall back to a fresh loop in both cases.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )
|
||||
|
||||
|
||||
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -220,17 +157,18 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
return f"--{k}"
|
||||
return f"--{k}={v}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
cur_path = path.parents[0]
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
|
||||
Reference in New Issue
Block a user