[s2s run_eval] new features (#7109)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import SummarizationModule, main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .run_eval_search import run_search
|
||||
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
|
||||
@@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
@@ -311,6 +312,58 @@ def test_run_eval(model):
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
@slow
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
def test_run_eval_search(model):
    """Drive ``run_search`` through a patched ``sys.argv`` and check its stdout report.

    The test writes a small English source file and a German reference file,
    invokes ``run_search`` with a ``--search`` specification, and then verifies
    that the captured stdout contains the expected strings and that the output
    file was created.
    """
    source_path = Path(tempfile.mkdtemp()) / "utest_input.source"
    generated_path = source_path.parent / "utest_output.txt"
    assert not generated_path.exists()

    english = ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"]
    german = [
        "Maschinelles Lernen ist großartig, oder?",
        "Ich esse gerne Bananen",
        "Morgen ist wieder ein toller Tag!",
    ]

    # Scores and references live in a second, separate temporary directory.
    scratch = Path(tempfile.mkdtemp())
    score_path = str(scratch / "scores.json")
    reference_path = str(scratch / "val.target")
    _dump_articles(source_path, english)
    _dump_articles(reference_path, german)

    if model == T5_TINY:
        task = "translation_en_to_de"
    else:
        task = "summarization"

    cli_args = ["run_eval_search.py", model, str(source_path), str(generated_path)]
    cli_args += ["--score_path", score_path]
    cli_args += ["--reference_path", reference_path]
    cli_args += ["--task", task]
    cli_args += ["--search", "num_beams=1:2 length_penalty=0.9:1.0"]

    with patch.object(sys, "argv", cli_args), CaptureStdout() as cs:
        run_search()

    # The metric names expected in the report depend on the task.
    metric_names = ["bleu"] if "translation" in task else ["rouge1", "rouge2", "rougeL"]
    must_appear = [" num_beams | length_penalty", model, "Best score args"] + metric_names
    must_not_appear = ["Info"]

    for needle in must_appear:
        assert needle in cs.out
    for needle in must_not_appear:
        assert needle not in cs.out

    assert generated_path.exists()
    os.remove(generated_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"],
|
||||
|
||||
Reference in New Issue
Block a user