[s2s run_eval] new features (#7109)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-09-16 10:59:57 -07:00
committed by GitHub
parent df165065c3
commit fdaf8ab349
5 changed files with 320 additions and 19 deletions

View File

@@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .run_eval_search import run_search
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
@@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase):
return model
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
@@ -311,6 +312,58 @@ def test_run_eval(model):
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
@slow
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
def test_run_eval_search(model):
    """Drive ``run_search`` through a patched ``sys.argv`` and check what it prints.

    English source sentences and German references are dumped to scratch
    files, ``run_search`` is invoked with a ``--search`` grid over
    ``num_beams`` and ``length_penalty``, and the captured stdout is checked
    for the table header, the model name, the "Best score args" line and the
    metric names matching the task.

    Both scratch directories are managed by ``tempfile.TemporaryDirectory`` so
    every file written during the test is removed, including when an
    assertion fails part-way through.
    """
    text = {
        "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
        "de": [
            "Maschinelles Lernen ist großartig, oder?",
            "Ich esse gerne Bananen",
            "Morgen ist wieder ein toller Tag!",
        ],
    }

    # Two separate directories: one holds the input/output files, the other
    # holds the scores and reference files.
    with tempfile.TemporaryDirectory() as io_dir, tempfile.TemporaryDirectory() as ref_dir:
        input_file_name = Path(io_dir) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()

        tmp_dir = Path(ref_dir)
        score_path = str(tmp_dir / "scores.json")
        reference_path = str(tmp_dir / "val.target")
        _dump_articles(input_file_name, text["en"])
        _dump_articles(reference_path, text["de"])

        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = [
            "run_eval_search.py",
            model,
            str(input_file_name),
            str(output_file_name),
            "--score_path",
            score_path,
            "--reference_path",
            reference_path,
            "--task",
            task,
            "--search",
            "num_beams=1:2 length_penalty=0.9:1.0",
        ]

        # run_search takes no arguments here, so the command line is supplied
        # by temporarily replacing sys.argv.
        with patch.object(sys, "argv", testargs):
            with CaptureStdout() as cs:
                run_search()

        expected_strings = [" num_beams | length_penalty", model, "Best score args"]
        # NOTE(review): "Info" presumably stands for log output that should not
        # reach stdout -- confirm against run_eval_search.
        un_expected_strings = ["Info"]
        if "translation" in task:
            expected_strings.append("bleu")
        else:
            expected_strings.extend(["rouge1", "rouge2", "rougeL"])

        for w in expected_strings:
            assert w in cs.out
        for w in un_expected_strings:
            assert w not in cs.out

        # Checked inside the ``with`` block: the output file disappears with
        # its directory on exit, so no explicit os.remove is needed.
        assert output_file_name.exists()
@pytest.mark.parametrize(
["model"],