[testing] port test_trainer_distributed to distributed pytest + TestCasePlus enhancements (#8107)
* move the helper code into testing_utils * port test_trainer_distributed to work with pytest * improve docs * simplify notes * doc * doc * style * doc * further improvements * torch might not be available * real fix * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1,20 +1,16 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.testing_utils import TestCasePlus, slow
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, slow
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import execute_async_std
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -166,11 +162,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
|
||||
# XXX: remove hardcoded path
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
argv = f"""
|
||||
args = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--data_dir {data_dir}
|
||||
--output_dir {output_dir}
|
||||
@@ -204,31 +198,16 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
|
||||
n_gpu = torch.cuda.device_count()
|
||||
if n_gpu > 1:
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
cur_path = path.parents[0]
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
distributed_args = (
|
||||
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
|
||||
)
|
||||
cmd = [sys.executable] + distributed_args + argv
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
if result.returncode > 0:
|
||||
pytest.fail(f"failed with returncode {result.returncode}")
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
{self.test_file_dir}/finetune_trainer.py
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
else:
|
||||
# 0 or 1 gpu
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
testargs = ["finetune_trainer.py"] + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user