diff --git a/examples/seq2seq/run_eval_search.py b/examples/seq2seq/run_eval_search.py index a100cb69c0..ae221d37b6 100644 --- a/examples/seq2seq/run_eval_search.py +++ b/examples/seq2seq/run_eval_search.py @@ -15,7 +15,6 @@ except ImportError: # To add a new task, simply list the score names that `run_eval.run_generate()` returns task_score_names = { "translation": ["bleu"], - "translation_en_to_de": ["bleu"], "summarization": ["rouge1", "rouge2", "rougeL"], } @@ -66,9 +65,7 @@ def run_search(): parser.add_argument( "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)" ) - parser.add_argument( - "--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys() - ) + parser.add_argument("--task", type=str, help="used for task_specific_params + metrics") parser.add_argument( "--info", nargs="?", @@ -81,8 +78,11 @@ def run_search(): args_main.extend(["--task", args.task]) args_normal = [prog] + args_main + # to support variations like translation_en_to_de" + task = "translation" if "translation" in args.task else "summarization" + matrix, col_names = parse_search_arg(args.search) - col_names[0:0] = task_score_names[args.task] # score cols first + col_names[0:0] = task_score_names[task] # score cols first col_widths = {col: len(str(col)) for col in col_names} results = [] for r in matrix: @@ -96,7 +96,7 @@ def run_search(): scores = run_generate(verbose=False) # make sure scores are first in the table result = OrderedDict() - for score in task_score_names[args.task]: + for score in task_score_names[task]: result[score] = scores[score] result.update(hparams) results.append(result) @@ -107,14 +107,14 @@ def run_search(): if l > col_widths[k]: col_widths[k] = l - results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True) + results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True) print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names])) print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names])) for row in results_sorted: print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names])) best = results_sorted[0] - for score in task_score_names[args.task]: + for score in task_score_names[task]: del best[score] best_args = [f"--{k} {v}" for k, v in best.items()] dyn_args = ["--bs", str(args.bs)] diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index ab14d288b5..1f664ce288 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random" BART_TINY = "sshleifer/bart-tiny-random" MBART_TINY = "sshleifer/tiny-mbart" MARIAN_TINY = "sshleifer/tiny-marian-en-de" +BERT_BASE_CASED = "bert-base-cased" +PEGASUS_XSUM = "google/pegasus-xsum" + stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks @@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase): return model -@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)]) -def test_run_eval(model): +def run_eval_tester(model): input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" output_file_name = input_file_name.parent / "utest_output.txt" assert not output_file_name.exists() @@ -293,28 +295,39 @@ def test_run_eval(model): _dump_articles(input_file_name, articles) score_path = str(Path(tempfile.mkdtemp()) / "scores.json") task = "translation_en_to_de" if model == T5_TINY else "summarization" - testargs = [ - "run_eval.py", - model, - str(input_file_name), - str(output_file_name), - "--score_path", - score_path, - "--task", - task, - "--num_beams", - "2", - "--length_penalty", - "2.0", - ] + testargs = f""" + run_eval_search.py + {model} + {input_file_name} + {output_file_name} + --score_path {score_path} + --task {task} + --num_beams 2 + --length_penalty 2.0 + """.split() + with patch.object(sys, "argv", testargs): run_generate() assert Path(output_file_name).exists() os.remove(Path(output_file_name)) +# test one model to quickly (no-@slow) catch simple problems and do an +# extensive testing of functionality with multiple models as @slow separately +def test_run_eval(): + run_eval_tester(T5_TINY) + + +# any extra models should go into the list here - can be slow @slow -@pytest.mark.parametrize("model", [pytest.param(T5_TINY)]) +@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY]) +def test_run_eval_slow(model): + run_eval_tester(model) + + +# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart) +@slow +@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY]) def test_run_eval_search(model): input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" output_file_name = input_file_name.parent / "utest_output.txt" @@ -335,20 +348,17 @@ def test_run_eval_search(model): _dump_articles(input_file_name, text["en"]) _dump_articles(reference_path, text["de"]) task = "translation_en_to_de" if model == T5_TINY else "summarization" - testargs = [ - "run_eval_search.py", - model, - str(input_file_name), - str(output_file_name), - "--score_path", - score_path, - "--reference_path", - reference_path, - "--task", - task, - "--search", - "num_beams=1:2 length_penalty=0.9:1.0", - ] + testargs = f""" + run_eval_search.py + --model_name {model} + --data_dir {str(input_file_name)} + --save_dir {str(output_file_name)} + --score_path {score_path} + --reference_path {reference_path}, + --task {task} + --search num_beams=1:2 length_penalty=0.9:1.0 + """.split() + with patch.object(sys, "argv", testargs): with CaptureStdout() as cs: run_search() @@ -367,8 +377,8 @@ def test_run_eval_search(model): @pytest.mark.parametrize( - ["model"], - [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)], + "model", + [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY], ) def test_finetune(model): args_d: dict = CHEAP_ARGS.copy() @@ -541,13 +551,13 @@ def test_pack_dataset(): @pytest.mark.parametrize( - ["tok_name"], + "tok_name", [ - pytest.param(MBART_TINY), - pytest.param(MARIAN_TINY), - pytest.param(T5_TINY), - pytest.param(BART_TINY), - pytest.param("google/pegasus-xsum"), + MBART_TINY, + MARIAN_TINY, + T5_TINY, + BART_TINY, + PEGASUS_XSUM, ], ) def test_seq2seq_dataset_truncation(tok_name): @@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name): break # No need to test every batch -@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")]) +@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED]) def test_legacy_dataset_truncation(tok): tokenizer = AutoTokenizer.from_pretrained(tok) tmp_dir = make_test_data_dir()