[testing] port test_trainer_distributed to distributed pytest + TestCasePlus enhancements (#8107)

* move the helper code into testing_utils

* port test_trainer_distributed to work with pytest

* improve docs

* simplify notes

* doc

* doc

* style

* doc

* further improvements

* torch might not be available

* real fix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2020-10-28 08:51:32 -07:00
committed by GitHub
parent 47dfa65b0c
commit 5423f2a9d4
6 changed files with 308 additions and 148 deletions

View File

@@ -1,20 +1,16 @@
import os
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, slow
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std
if is_torch_available():
@@ -166,11 +162,9 @@ class TestFinetuneTrainer(TestCasePlus):
trainer.train()
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
# XXX: remove hardcoded path
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
argv = f"""
args = f"""
--model_name_or_path {model_name}
--data_dir {data_dir}
--output_dir {output_dir}
@@ -204,31 +198,16 @@ class TestFinetuneTrainer(TestCasePlus):
n_gpu = torch.cuda.device_count()
if n_gpu > 1:
path = Path(__file__).resolve()
cur_path = path.parents[0]
path = Path(__file__).resolve()
examples_path = path.parents[1]
src_path = f"{path.parents[2]}/src"
env = os.environ.copy()
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
distributed_args = (
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
)
cmd = [sys.executable] + distributed_args + argv
print("\nRunning: ", " ".join(cmd))
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
assert result.stdout, "produced no output"
if result.returncode > 0:
pytest.fail(f"failed with returncode {result.returncode}")
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={n_gpu}
{self.test_file_dir}/finetune_trainer.py
""".split()
cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
else:
# 0 or 1 gpu
testargs = ["finetune_trainer.py"] + argv
testargs = ["finetune_trainer.py"] + args
with patch.object(sys, "argv", testargs):
main()

View File

@@ -5,12 +5,10 @@ import os
import sys
from pathlib import Path
import pytest
from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
from .utils import execute_async_std, load_json
from .utils import load_json
if is_torch_available():
@@ -157,23 +155,9 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
return f"--{k}"
return f"--{k}={v}"
path = Path(__file__).resolve()
cur_path = path.parents[0]
examples_path = path.parents[1]
src_path = f"{path.parents[2]}/src"
env = os.environ.copy()
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
print("\nRunning: ", " ".join(cmd))
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
assert result.stdout, "produced no output"
if result.returncode > 0:
pytest.fail(f"failed with returncode {result.returncode}")
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
execute_subprocess_async(cmd, env=self.get_env())
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}

View File

@@ -5,7 +5,6 @@ import math
import os
import pickle
import socket
import sys
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
@@ -641,74 +640,6 @@ def check_output_dir(args, expected_items=0):
):
raise ValueError(
f"Output directory ({args.output_dir}) already exists and "
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome."
)
# the following code deals with async io between processes
# adapted from https://stackoverflow.com/a/59041913/9201239
import asyncio # noqa
class _RunOutput:
def __init__(self, returncode, stdout, stderr):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
async def _read_stream(stream, callback):
while True:
line = await stream.readline()
if line:
callback(line)
else:
break
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
if echo:
print(cmd)
p = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdin=stdin,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
out = []
err = []
def tee(line, sink, pipe, label=""):
line = line.decode("utf-8").rstrip()
sink.append(line)
if not quiet:
print(label, line, file=pipe)
await asyncio.wait(
[
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
],
timeout=timeout,
)
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
#
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
return _RunOutput(await p.wait(), out, err)
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
loop = asyncio.get_event_loop()
result = loop.run_until_complete(
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
)
return result