From fdaf8ab34999c88f89e6e95dca1dac7027731205 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 16 Sep 2020 10:59:57 -0700 Subject: [PATCH] [s2s run_eval] new features (#7109) Co-authored-by: Sam Shleifer --- examples/seq2seq/README.md | 63 +++++++++- examples/seq2seq/run_eval.py | 62 ++++++++-- examples/seq2seq/run_eval_search.py | 139 ++++++++++++++++++++++ examples/seq2seq/test_seq2seq_examples.py | 55 ++++++++- examples/seq2seq/utils.py | 20 +++- 5 files changed, 320 insertions(+), 19 deletions(-) create mode 100644 examples/seq2seq/run_eval_search.py diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index f1c0823311..7223f87c3e 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -46,7 +46,7 @@ export DATA_DIR=${PWD}/wmt_en_de #### Private Data -If you are using your own data, it must be formatted as one directory with 6 files: +If you are using your own data, it must be formatted as one directory with 6 files: ``` train.source train.target @@ -228,6 +228,67 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_ --bs 32 ``` +#### run_eval tips and tricks + +When using `run_eval.py`, the following features can be useful: + +* if you running the script multiple times and want to make it easier to track what arguments produced that output, use `--dump-args`. Along with the results it will also dump any custom params that were passed to the script. For example if you used: `--num_beams 8 --early_stopping true`, the output will be: + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True} + ``` + + `--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful to pass things that weren't in the argument list, e.g. a language pair `--info "lang:en-ru"`. But also if you pass `--info` without a value it will fallback to the current date/time string, e.g. `2020-09-13 18:44:43`. + + If using `--dump-args --info`, the output will be: + + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'} + ``` + + If using `--dump-args --info "pair:en-ru chkpt=best`, the output will be: + + ``` + {'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair=en-ru chkpt=best'} + ``` + + +* if you need to perform a parametric search in order to find the best ones that lead to the highest BLEU score, let `run_eval_search.py` to do the searching for you. + + The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to ``run_eval.py`` as additional args. + + The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.: + ``` + --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false" + ``` + which will generate `12` `(2*3*2)` searches for a product of each hparam. For example the example that was just used will invoke `run_eval.py` repeatedly with: + + ``` + --num_beams 5 --length_penalty 0.8 --early_stopping true + --num_beams 5 --length_penalty 0.8 --early_stopping false + [...] + --num_beams 10 --length_penalty 1.2 --early_stopping false + ``` + + On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments. + +``` +bleu | num_beams | length_penalty | early_stopping +----- | --------- | -------------- | -------------- +26.71 | 5 | 1.1 | 1 +26.66 | 5 | 0.9 | 1 +26.66 | 5 | 0.9 | 0 +26.41 | 5 | 1.1 | 0 +21.94 | 1 | 0.9 | 1 +21.94 | 1 | 0.9 | 0 +21.94 | 1 | 1.1 | 1 +21.94 | 1 | 1.1 | 0 + +Best score args: +stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True +``` + +If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other. + ### DistilBART ![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png) diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index a9adb566e1..23909a3d07 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -1,4 +1,5 @@ import argparse +import datetime import json import time import warnings @@ -15,9 +16,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer logger = getLogger(__name__) try: - from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params + from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params except ImportError: - from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params + from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -72,7 +73,26 @@ def generate_summaries_or_translations( return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4)) -def run_generate(): +def datetime_now(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def run_generate(verbose=True): + """ + + Takes input text, generates output, and then using reference calculates the BLEU scores. + + The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed. + + Args: + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout + + Returns: + a tuple: ``(scores, params}`` + - ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}`` + - ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}`` + """ + parser = argparse.ArgumentParser() parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") parser.add_argument("input_path", type=str, help="like cnn_dm/test.source") @@ -89,11 +109,19 @@ def run_generate(): "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all." ) parser.add_argument("--fp16", action="store_true") + parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results") + parser.add_argument( + "--info", + nargs="?", + type=str, + const=datetime_now(), + help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.", + ) # Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate args, rest = parser.parse_known_args() - parsed = parse_numeric_cl_kwargs(rest) - if parsed: - print(f"parsed the following generate kwargs: {parsed}") + parsed_args = parse_numeric_n_bool_cl_kwargs(rest) + if parsed_args and verbose: + print(f"parsed the following generate kwargs: {parsed_args}") examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] if args.n_obs > 0: examples = examples[: args.n_obs] @@ -109,23 +137,35 @@ def run_generate(): fp16=args.fp16, task=args.task, prefix=args.prefix, - **parsed, + **parsed_args, ) + if args.reference_path is None: - return + return {} + # Compute scores score_fn = calculate_bleu if "translation" in args.task else calculate_rouge output_lns = [x.rstrip() for x in open(args.save_path).readlines()] reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)] scores: dict = score_fn(output_lns, reference_lns) scores.update(runtime_metrics) - print(scores) + + if args.dump_args: + scores.update(parsed_args) + if args.info: + scores["info"] = args.info + + if verbose: + print(*scores) + if args.score_path is not None: - json.dump(scores, open(args.score_path, "w")) + path = args.score_path + json.dump(scores, open(path, "w")) + return scores if __name__ == "__main__": # Usage for MT: # python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@ - run_generate() + run_generate(verbose=True) diff --git a/examples/seq2seq/run_eval_search.py b/examples/seq2seq/run_eval_search.py new file mode 100644 index 0000000000..a100cb69c0 --- /dev/null +++ b/examples/seq2seq/run_eval_search.py @@ -0,0 +1,139 @@ +import argparse +import itertools +import operator +import sys +from collections import OrderedDict + + +try: + from .run_eval import datetime_now, run_generate +except ImportError: + from run_eval import datetime_now, run_generate + + +# A table of supported tasks and the list of scores in the order of importance to be sorted by. +# To add a new task, simply list the score names that `run_eval.run_generate()` returns +task_score_names = { + "translation": ["bleu"], + "translation_en_to_de": ["bleu"], + "summarization": ["rouge1", "rouge2", "rougeL"], +} + + +def parse_search_arg(search): + groups = search.split() + entries = {k: vs for k, vs in (g.split("=") for g in groups)} + entry_names = list(entries.keys()) + sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()] + matrix = [list(x) for x in itertools.product(*sets)] + return matrix, entry_names + + +def run_search(): + """ + Run parametric search over the desired hparam space with help of ``run_eval.py``. + + All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args. + + The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.: + ``` + --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false" + ``` + which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with: + + ``` + --num_beams 5 --length_penalty 0.8 --early_stopping true + --num_beams 5 --length_penalty 0.8 --early_stopping false + [...] + --num_beams 10 --length_penalty 1.2 --early_stopping false + ``` + + On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments. + + + """ + prog = sys.argv[0] + + parser = argparse.ArgumentParser( + usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list." + ) + parser.add_argument( + "--search", + type=str, + required=False, + help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"', + ) + parser.add_argument( + "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)" + ) + parser.add_argument( + "--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys() + ) + parser.add_argument( + "--info", + nargs="?", + type=str, + const=datetime_now(), + help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.", + ) + args, args_main = parser.parse_known_args() + # we share some of the args + args_main.extend(["--task", args.task]) + args_normal = [prog] + args_main + + matrix, col_names = parse_search_arg(args.search) + col_names[0:0] = task_score_names[args.task] # score cols first + col_widths = {col: len(str(col)) for col in col_names} + results = [] + for r in matrix: + hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)} + args_exp = " ".join(r).split() + args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM + sys.argv = args_normal + args_exp + + # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry + + scores = run_generate(verbose=False) + # make sure scores are first in the table + result = OrderedDict() + for score in task_score_names[args.task]: + result[score] = scores[score] + result.update(hparams) + results.append(result) + + # find widest entries + for k, v in result.items(): + l = len(str(v)) + if l > col_widths[k]: + col_widths[k] = l + + results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True) + print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names])) + print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names])) + for row in results_sorted: + print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names])) + + best = results_sorted[0] + for score in task_score_names[args.task]: + del best[score] + best_args = [f"--{k} {v}" for k, v in best.items()] + dyn_args = ["--bs", str(args.bs)] + if args.info: + print(f"\nInfo: {args.info}") + print("\nBest score args:") + print(" ".join(args_main + best_args + dyn_args)) + + return results_sorted + + +if __name__ == "__main__": + # Usage: + # [normal-run_eval_search.py cmd plus] \ + # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false" + # + # Example: + # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \ + # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \ + # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \ + # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false" + run_search() diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 26fb20bdee..a40572a96e 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint from .finetune import SummarizationModule, main from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate +from .run_eval_search import run_search from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json @@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase): return model -@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)]) +@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)]) def test_run_eval(model): input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" output_file_name = input_file_name.parent / "utest_output.txt" @@ -311,6 +312,58 @@ def test_run_eval(model): assert Path(output_file_name).exists() os.remove(Path(output_file_name)) +@slow +@pytest.mark.parametrize("model", [pytest.param(T5_TINY)]) +def test_run_eval_search(model): + input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source" + output_file_name = input_file_name.parent / "utest_output.txt" + assert not output_file_name.exists() + + text = { + "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"], + "de": [ + "Maschinelles Lernen ist großartig, oder?", + "Ich esse gerne Bananen", + "Morgen ist wieder ein toller Tag!", + ], + } + + tmp_dir = Path(tempfile.mkdtemp()) + score_path = str(tmp_dir / "scores.json") + reference_path = str(tmp_dir / "val.target") + _dump_articles(input_file_name, text["en"]) + _dump_articles(reference_path, text["de"]) + task = "translation_en_to_de" if model == T5_TINY else "summarization" + testargs = [ + "run_eval_search.py", + model, + str(input_file_name), + str(output_file_name), + "--score_path", + score_path, + "--reference_path", + reference_path, + "--task", + task, + "--search", + "num_beams=1:2 length_penalty=0.9:1.0", + ] + with patch.object(sys, "argv", testargs): + with CaptureStdout() as cs: + run_search() + expected_strings = [" num_beams | length_penalty", model, "Best score args"] + un_expected_strings = ["Info"] + if "translation" in task: + expected_strings.append("bleu") + else: + expected_strings.extend(["rouge1", "rouge2", "rougeL"]) + for w in expected_strings: + assert w in cs.out + for w in un_expected_strings: + assert w not in cs.out + assert Path(output_file_name).exists() + os.remove(Path(output_file_name)) + @pytest.mark.parametrize( ["model"], diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index c049b4372e..b1d35cf10e 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -377,18 +377,26 @@ def assert_not_all_frozen(model): # CLI Parsing utils -def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]: - """Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric.""" +def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: + """ + Parse an argv list of unspecified command line args to a dict. + Assumes all values are either numeric or boolean in the form of true/false. + """ result = {} assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" num_pairs = len(unparsed_args) // 2 for pair_num in range(num_pairs): i = 2 * pair_num assert unparsed_args[i].startswith("--") - try: - value = int(unparsed_args[i + 1]) - except ValueError: - value = float(unparsed_args[i + 1]) # this can raise another informative ValueError + if unparsed_args[i + 1].lower() == "true": + value = True + elif unparsed_args[i + 1].lower() == "false": + value = False + else: + try: + value = int(unparsed_args[i + 1]) + except ValueError: + value = float(unparsed_args[i + 1]) # this can raise another informative ValueError result[unparsed_args[i][2:]] = value return result