[s2s] run_eval/run_eval_search tweaks (#7192)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
BERT_BASE_CASED = "bert-base-cased"
|
||||
PEGASUS_XSUM = "google/pegasus-xsum"
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
@@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval(model):
|
||||
def run_eval_tester(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
@@ -293,28 +295,39 @@ def test_run_eval(model):
|
||||
_dump_articles(input_file_name, articles)
|
||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--task",
|
||||
task,
|
||||
"--num_beams",
|
||||
"2",
|
||||
"--length_penalty",
|
||||
"2.0",
|
||||
]
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{input_file_name}
|
||||
{output_file_name}
|
||||
--score_path {score_path}
|
||||
--task {task}
|
||||
--num_beams 2
|
||||
--length_penalty 2.0
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
|
||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||
# extensive testing of functionality with multiple models as @slow separately
|
||||
def test_run_eval():
|
||||
run_eval_tester(T5_TINY)
|
||||
|
||||
|
||||
# any extra models should go into the list here - can be slow
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
|
||||
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
|
||||
def test_run_eval_slow(model):
|
||||
run_eval_tester(model)
|
||||
|
||||
|
||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
|
||||
def test_run_eval_search(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
@@ -335,20 +348,17 @@ def test_run_eval_search(model):
|
||||
_dump_articles(input_file_name, text["en"])
|
||||
_dump_articles(reference_path, text["de"])
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval_search.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--reference_path",
|
||||
reference_path,
|
||||
"--task",
|
||||
task,
|
||||
"--search",
|
||||
"num_beams=1:2 length_penalty=0.9:1.0",
|
||||
]
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
--model_name {model}
|
||||
--data_dir {str(input_file_name)}
|
||||
--save_dir {str(output_file_name)}
|
||||
--score_path {score_path}
|
||||
--reference_path {reference_path},
|
||||
--task {task}
|
||||
--search num_beams=1:2 length_penalty=0.9:1.0
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureStdout() as cs:
|
||||
run_search()
|
||||
@@ -367,8 +377,8 @@ def test_run_eval_search(model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"],
|
||||
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
"model",
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
@@ -541,13 +551,13 @@ def test_pack_dataset():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok_name"],
|
||||
"tok_name",
|
||||
[
|
||||
pytest.param(MBART_TINY),
|
||||
pytest.param(MARIAN_TINY),
|
||||
pytest.param(T5_TINY),
|
||||
pytest.param(BART_TINY),
|
||||
pytest.param("google/pegasus-xsum"),
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_seq2seq_dataset_truncation(tok_name):
|
||||
@@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
|
||||
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
|
||||
def test_legacy_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
|
||||
Reference in New Issue
Block a user