[seq2seq testing] multigpu test run via subprocess (#7281)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
261
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
Normal file
261
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
||||
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"max_tokens_per_batch": None,
|
||||
"supervise_forward": True,
|
||||
"normalize_hidden": True,
|
||||
"label_smoothing": 0.2,
|
||||
"eval_max_gen_length": None,
|
||||
"eval_beams": 1,
|
||||
"val_metric": "loss",
|
||||
"save_top_k": 1,
|
||||
"adafactor": True,
|
||||
"early_stopping_patience": 2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"accumulate_grad_batches": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"lr_scheduler": "linear",
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"max_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
content = "\n".join(articles)
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
||||
|
||||
def make_test_data_dir(tmp_dir):
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
# XXX: a candidate for testing_utils (python>=3.6)
|
||||
# https://stackoverflow.com/a/59041913/9201239
|
||||
import asyncio # noqa
|
||||
|
||||
|
||||
class RunOutput:
|
||||
def __init__(self, returncode, stdout, stderr):
|
||||
self.returncode = returncode
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
async def _read_stream(stream, callback):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if line:
|
||||
callback(line)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
|
||||
if echo:
|
||||
print(cmd)
|
||||
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
cmd[0],
|
||||
*cmd[1:],
|
||||
stdin=stdin,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
out = []
|
||||
err = []
|
||||
|
||||
def tee(line, sink, pipe, label=""):
|
||||
line = line.decode("utf-8").rstrip()
|
||||
sink.append(line)
|
||||
if not quiet:
|
||||
print(label, line, file=pipe)
|
||||
|
||||
await asyncio.wait(
|
||||
[
|
||||
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
|
||||
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
|
||||
],
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
||||
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
||||
#
|
||||
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
|
||||
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
|
||||
return RunOutput(await p.wait(), out, err)
|
||||
|
||||
|
||||
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> RunOutput:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(
|
||||
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@require_torch_multigpu
|
||||
def test_multigpu(self):
|
||||
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
overwrite_output_dir=True,
|
||||
sortish_sampler=True,
|
||||
)
|
||||
self._test_distiller_cli_fork(updates, check_contents=False)
|
||||
|
||||
def _test_distiller_cli_fork(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
|
||||
def convert(k, v):
|
||||
if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
|
||||
return ""
|
||||
if v is False or v is None:
|
||||
return ""
|
||||
if v is True: # or len(str(v))==0:
|
||||
return f"--{k}"
|
||||
return f"--{k}={v}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, "./examples/seq2seq/distillation.py"] + cli_args
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
if result.returncode > 0:
|
||||
pytest.fail(f"failed with returncode {result.returncode}")
|
||||
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
ckpt_files = [p for p in contents if p.endswith("ckpt")]
|
||||
assert len(ckpt_files) > 0
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
# get the following from the module, (we don't have access to `model` here)
|
||||
metrics_save_path = os.path.join(output_dir, "metrics.json")
|
||||
val_metric = "rouge2"
|
||||
|
||||
metrics = load_json(metrics_save_path)
|
||||
# {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
|
||||
print(metrics)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
Reference in New Issue
Block a user