[s2s] run_eval/run_eval_search tweaks (#7192)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-09-17 11:26:38 -07:00
committed by GitHub
parent 9c5bcab5b0
commit efeab6a3f1
2 changed files with 58 additions and 48 deletions

View File

@@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
@@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
return model
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval(model):
def run_eval_tester(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
assert not output_file_name.exists()
@@ -293,28 +295,39 @@ def test_run_eval(model):
_dump_articles(input_file_name, articles)
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--task",
task,
"--num_beams",
"2",
"--length_penalty",
"2.0",
]
testargs = f"""
run_eval_search.py
{model}
{input_file_name}
{output_file_name}
--score_path {score_path}
--task {task}
--num_beams 2
--length_penalty 2.0
""".split()
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
# test one model to quickly (no-@slow) catch simple problems and do an
# extensive testing of functionality with multiple models as @slow separately
def test_run_eval():
run_eval_tester(T5_TINY)
# any extra models should go into the list here - can be slow
@slow
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
def test_run_eval_slow(model):
run_eval_tester(model)
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
@slow
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
def test_run_eval_search(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
@@ -335,20 +348,17 @@ def test_run_eval_search(model):
_dump_articles(input_file_name, text["en"])
_dump_articles(reference_path, text["de"])
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval_search.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--reference_path",
reference_path,
"--task",
task,
"--search",
"num_beams=1:2 length_penalty=0.9:1.0",
]
testargs = f"""
run_eval_search.py
--model_name {model}
--data_dir {str(input_file_name)}
--save_dir {str(output_file_name)}
--score_path {score_path}
--reference_path {reference_path},
--task {task}
--search num_beams=1:2 length_penalty=0.9:1.0
""".split()
with patch.object(sys, "argv", testargs):
with CaptureStdout() as cs:
run_search()
@@ -367,8 +377,8 @@ def test_run_eval_search(model):
@pytest.mark.parametrize(
["model"],
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
"model",
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
@@ -541,13 +551,13 @@ def test_pack_dataset():
@pytest.mark.parametrize(
["tok_name"],
"tok_name",
[
pytest.param(MBART_TINY),
pytest.param(MARIAN_TINY),
pytest.param(T5_TINY),
pytest.param(BART_TINY),
pytest.param("google/pegasus-xsum"),
MBART_TINY,
MARIAN_TINY,
T5_TINY,
BART_TINY,
PEGASUS_XSUM,
],
)
def test_seq2seq_dataset_truncation(tok_name):
@@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
break # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
def test_legacy_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()