[s2s] test_distributed_eval (#8315)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -3,7 +3,14 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
require_torch_gpu,
|
||||
require_torch_multigpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir
|
||||
from .utils import load_json
|
||||
@@ -80,3 +87,30 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_distributed_eval(self):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
--model_name Helsinki-NLP/opus-mt-en-ro
|
||||
--save_dir {output_dir}
|
||||
--data_dir test_data/wmt_en_ro
|
||||
--num_beams 2
|
||||
--task translation
|
||||
""".split()
|
||||
|
||||
# we want this test to run even if there is only one GPU, but if there are more we use them all
|
||||
n_gpu = get_gpu_count()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
{self.test_file_dir}/run_distributed_eval.py
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
metrics_save_path = os.path.join(output_dir, "test_bleu.json")
|
||||
metrics = load_json(metrics_save_path)
|
||||
# print(metrics)
|
||||
self.assertGreaterEqual(metrics["bleu"], 25)
|
||||
|
||||
Reference in New Issue
Block a user