[s2s trainer] tests to use distributed on multi-gpu machine (#7965)

This commit is contained in:
Stas Bekman
2020-10-22 14:26:22 -07:00
committed by GitHub
parent 64b24bb3c2
commit 023f0f3708
3 changed files with 121 additions and 78 deletions

View File

@@ -1,15 +1,23 @@
import os
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
from transformers import is_torch_available
from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std
if is_torch_available():
import torch
set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
@@ -25,7 +33,7 @@ class TestFinetuneTrainer(TestCasePlus):
@slow
def test_finetune_trainer_slow(self):
# There is a missing call to __init__process_group somewhere
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
# Check metrics
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
@@ -43,6 +51,8 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_results.json" in contents
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
# XXX: remove hardcoded path
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
argv = f"""
@@ -77,8 +87,34 @@ class TestFinetuneTrainer(TestCasePlus):
""".split()
# --eval_beams 2
testargs = ["finetune_trainer.py"] + argv
with patch.object(sys, "argv", testargs):
main()
n_gpu = torch.cuda.device_count()
if n_gpu > 1:
path = Path(__file__).resolve()
cur_path = path.parents[0]
path = Path(__file__).resolve()
examples_path = path.parents[1]
src_path = f"{path.parents[2]}/src"
env = os.environ.copy()
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
distributed_args = (
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
)
cmd = [sys.executable] + distributed_args + argv
print("\nRunning: ", " ".join(cmd))
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
assert result.stdout, "produced no output"
if result.returncode > 0:
pytest.fail(f"failed with returncode {result.returncode}")
else:
# 0 or 1 gpu
testargs = ["finetune_trainer.py"] + argv
with patch.object(sys, "argv", testargs):
main()
return output_dir