[s2s] run_eval.py QOL improvements and cleanup(#6746)
This commit is contained in:
@@ -252,13 +252,24 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval_bart(model):
|
||||
def test_run_eval(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(input_file_name, articles)
|
||||
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
|
||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--task",
|
||||
task,
|
||||
]
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
|
||||
Reference in New Issue
Block a user