[s2s trainer] tests to use distributed on multi-gpu machine (#7965)

This commit is contained in:
Stas Bekman
2020-10-22 14:26:22 -07:00
committed by GitHub
parent 64b24bb3c2
commit 023f0f3708
3 changed files with 121 additions and 78 deletions

View File

@@ -5,6 +5,7 @@ import math
import os
import pickle
import socket
import sys
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
@@ -643,3 +644,71 @@ def check_output_dir(args, expected_items=0):
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome."
)
# the following code deals with async io between processes
# adapted from https://stackoverflow.com/a/59041913/9201239
import asyncio # noqa
class _RunOutput:
def __init__(self, returncode, stdout, stderr):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
async def _read_stream(stream, callback):
while True:
line = await stream.readline()
if line:
callback(line)
else:
break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
if echo:
print(cmd)
p = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdin=stdin,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
out = []
err = []
def tee(line, sink, pipe, label=""):
line = line.decode("utf-8").rstrip()
sink.append(line)
if not quiet:
print(label, line, file=pipe)
await asyncio.wait(
[
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
],
timeout=timeout,
)
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
#
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
return _RunOutput(await p.wait(), out, err)
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
loop = asyncio.get_event_loop()
result = loop.run_until_complete(
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
)
return result