[s2s trainer] tests to use distributed on multi-gpu machine (#7965)
This commit is contained in:
@@ -1,15 +1,23 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import TestCasePlus, slow
|
from transformers.testing_utils import TestCasePlus, slow
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .finetune_trainer import main
|
from .finetune_trainer import main
|
||||||
from .test_seq2seq_examples import MBART_TINY
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
|
from .utils import execute_async_std
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
@@ -25,7 +33,7 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow(self):
|
def test_finetune_trainer_slow(self):
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
@@ -43,6 +51,8 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
assert "test_results.json" in contents
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||||
|
|
||||||
|
# XXX: remove hardcoded path
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
argv = f"""
|
argv = f"""
|
||||||
@@ -77,8 +87,34 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
""".split()
|
""".split()
|
||||||
# --eval_beams 2
|
# --eval_beams 2
|
||||||
|
|
||||||
testargs = ["finetune_trainer.py"] + argv
|
n_gpu = torch.cuda.device_count()
|
||||||
with patch.object(sys, "argv", testargs):
|
if n_gpu > 1:
|
||||||
main()
|
|
||||||
|
path = Path(__file__).resolve()
|
||||||
|
cur_path = path.parents[0]
|
||||||
|
|
||||||
|
path = Path(__file__).resolve()
|
||||||
|
examples_path = path.parents[1]
|
||||||
|
src_path = f"{path.parents[2]}/src"
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||||
|
|
||||||
|
distributed_args = (
|
||||||
|
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
|
||||||
|
)
|
||||||
|
cmd = [sys.executable] + distributed_args + argv
|
||||||
|
|
||||||
|
print("\nRunning: ", " ".join(cmd))
|
||||||
|
|
||||||
|
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||||
|
|
||||||
|
assert result.stdout, "produced no output"
|
||||||
|
if result.returncode > 0:
|
||||||
|
pytest.fail(f"failed with returncode {result.returncode}")
|
||||||
|
else:
|
||||||
|
# 0 or 1 gpu
|
||||||
|
testargs = ["finetune_trainer.py"] + argv
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
main()
|
||||||
|
|
||||||
return output_dir
|
return output_dir
|
||||||
|
|||||||
@@ -6,11 +6,15 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
||||||
|
|
||||||
from .utils import load_json
|
from .utils import execute_async_std, load_json
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -106,73 +110,6 @@ def make_test_data_dir(tmp_dir):
|
|||||||
return tmp_dir
|
return tmp_dir
|
||||||
|
|
||||||
|
|
||||||
# XXX: a candidate for testing_utils (python>=3.6)
|
|
||||||
# https://stackoverflow.com/a/59041913/9201239
|
|
||||||
import asyncio # noqa
|
|
||||||
|
|
||||||
|
|
||||||
class RunOutput:
|
|
||||||
def __init__(self, returncode, stdout, stderr):
|
|
||||||
self.returncode = returncode
|
|
||||||
self.stdout = stdout
|
|
||||||
self.stderr = stderr
|
|
||||||
|
|
||||||
|
|
||||||
async def _read_stream(stream, callback):
|
|
||||||
while True:
|
|
||||||
line = await stream.readline()
|
|
||||||
if line:
|
|
||||||
callback(line)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
|
|
||||||
if echo:
|
|
||||||
print(cmd)
|
|
||||||
|
|
||||||
p = await asyncio.create_subprocess_exec(
|
|
||||||
cmd[0],
|
|
||||||
*cmd[1:],
|
|
||||||
stdin=stdin,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
env=env,
|
|
||||||
)
|
|
||||||
out = []
|
|
||||||
err = []
|
|
||||||
|
|
||||||
def tee(line, sink, pipe, label=""):
|
|
||||||
line = line.decode("utf-8").rstrip()
|
|
||||||
sink.append(line)
|
|
||||||
if not quiet:
|
|
||||||
print(label, line, file=pipe)
|
|
||||||
|
|
||||||
await asyncio.wait(
|
|
||||||
[
|
|
||||||
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
|
|
||||||
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
|
|
||||||
],
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
|
||||||
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
|
||||||
#
|
|
||||||
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
|
|
||||||
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
|
|
||||||
return RunOutput(await p.wait(), out, err)
|
|
||||||
|
|
||||||
|
|
||||||
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
result = loop.run_until_complete(
|
|
||||||
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -220,17 +157,18 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
|||||||
return f"--{k}"
|
return f"--{k}"
|
||||||
return f"--{k}={v}"
|
return f"--{k}={v}"
|
||||||
|
|
||||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
|
||||||
cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
|
|
||||||
|
|
||||||
print("\nRunning: ", " ".join(cmd))
|
|
||||||
|
|
||||||
path = Path(__file__).resolve()
|
path = Path(__file__).resolve()
|
||||||
|
cur_path = path.parents[0]
|
||||||
examples_path = path.parents[1]
|
examples_path = path.parents[1]
|
||||||
src_path = f"{path.parents[2]}/src"
|
src_path = f"{path.parents[2]}/src"
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||||
|
|
||||||
|
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||||
|
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
|
||||||
|
|
||||||
|
print("\nRunning: ", " ".join(cmd))
|
||||||
|
|
||||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||||
|
|
||||||
assert result.stdout, "produced no output"
|
assert result.stdout, "produced no output"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import socket
|
import socket
|
||||||
|
import sys
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||||
@@ -643,3 +644,71 @@ def check_output_dir(args, expected_items=0):
|
|||||||
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
|
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
|
||||||
"Use --overwrite_output_dir to overcome."
|
"Use --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# the following code deals with async io between processes
|
||||||
|
|
||||||
|
# adapted from https://stackoverflow.com/a/59041913/9201239
|
||||||
|
import asyncio # noqa
|
||||||
|
|
||||||
|
|
||||||
|
class _RunOutput:
|
||||||
|
def __init__(self, returncode, stdout, stderr):
|
||||||
|
self.returncode = returncode
|
||||||
|
self.stdout = stdout
|
||||||
|
self.stderr = stderr
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_stream(stream, callback):
|
||||||
|
while True:
|
||||||
|
line = await stream.readline()
|
||||||
|
if line:
|
||||||
|
callback(line)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
||||||
|
if echo:
|
||||||
|
print(cmd)
|
||||||
|
|
||||||
|
p = await asyncio.create_subprocess_exec(
|
||||||
|
cmd[0],
|
||||||
|
*cmd[1:],
|
||||||
|
stdin=stdin,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
out = []
|
||||||
|
err = []
|
||||||
|
|
||||||
|
def tee(line, sink, pipe, label=""):
|
||||||
|
line = line.decode("utf-8").rstrip()
|
||||||
|
sink.append(line)
|
||||||
|
if not quiet:
|
||||||
|
print(label, line, file=pipe)
|
||||||
|
|
||||||
|
await asyncio.wait(
|
||||||
|
[
|
||||||
|
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
|
||||||
|
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
|
||||||
|
],
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
||||||
|
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
||||||
|
#
|
||||||
|
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
|
||||||
|
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
|
||||||
|
return _RunOutput(await p.wait(), out, err)
|
||||||
|
|
||||||
|
|
||||||
|
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = loop.run_until_complete(
|
||||||
|
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user