[s2s run_eval] new features (#7109)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -228,6 +228,67 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
|
||||
--bs 32
|
||||
```
|
||||
|
||||
#### run_eval tips and tricks
|
||||
|
||||
When using `run_eval.py`, the following features can be useful:
|
||||
|
||||
* if you running the script multiple times and want to make it easier to track what arguments produced that output, use `--dump-args`. Along with the results it will also dump any custom params that were passed to the script. For example if you used: `--num_beams 8 --early_stopping true`, the output will be:
|
||||
```
|
||||
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True}
|
||||
```
|
||||
|
||||
`--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful to pass things that weren't in the argument list, e.g. a language pair `--info "lang:en-ru"`. But also if you pass `--info` without a value it will fallback to the current date/time string, e.g. `2020-09-13 18:44:43`.
|
||||
|
||||
If using `--dump-args --info`, the output will be:
|
||||
|
||||
```
|
||||
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'}
|
||||
```
|
||||
|
||||
If using `--dump-args --info "pair:en-ru chkpt=best`, the output will be:
|
||||
|
||||
```
|
||||
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair=en-ru chkpt=best'}
|
||||
```
|
||||
|
||||
|
||||
* if you need to perform a parametric search in order to find the best ones that lead to the highest BLEU score, let `run_eval_search.py` to do the searching for you.
|
||||
|
||||
The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to ``run_eval.py`` as additional args.
|
||||
|
||||
The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.:
|
||||
```
|
||||
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
|
||||
```
|
||||
which will generate `12` `(2*3*2)` searches for a product of each hparam. For example the example that was just used will invoke `run_eval.py` repeatedly with:
|
||||
|
||||
```
|
||||
--num_beams 5 --length_penalty 0.8 --early_stopping true
|
||||
--num_beams 5 --length_penalty 0.8 --early_stopping false
|
||||
[...]
|
||||
--num_beams 10 --length_penalty 1.2 --early_stopping false
|
||||
```
|
||||
|
||||
On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
|
||||
|
||||
```
|
||||
bleu | num_beams | length_penalty | early_stopping
|
||||
----- | --------- | -------------- | --------------
|
||||
26.71 | 5 | 1.1 | 1
|
||||
26.66 | 5 | 0.9 | 1
|
||||
26.66 | 5 | 0.9 | 0
|
||||
26.41 | 5 | 1.1 | 0
|
||||
21.94 | 1 | 0.9 | 1
|
||||
21.94 | 1 | 0.9 | 0
|
||||
21.94 | 1 | 1.1 | 1
|
||||
21.94 | 1 | 1.1 | 0
|
||||
|
||||
Best score args:
|
||||
stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True
|
||||
```
|
||||
|
||||
If you pass `--info "some experiment-specific info"` it will get printed before the results table - this is useful for scripting and multiple runs, so one can tell the different sets of results from each other.
|
||||
|
||||
|
||||
### DistilBART
|
||||

|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
import warnings
|
||||
@@ -15,9 +16,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
||||
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -72,7 +73,26 @@ def generate_summaries_or_translations(
|
||||
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
|
||||
|
||||
|
||||
def run_generate():
|
||||
def datetime_now():
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def run_generate(verbose=True):
|
||||
"""
|
||||
|
||||
Takes input text, generates output, and then using reference calculates the BLEU scores.
|
||||
|
||||
The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.
|
||||
|
||||
Args:
|
||||
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout
|
||||
|
||||
Returns:
|
||||
a tuple: ``(scores, params}``
|
||||
- ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``
|
||||
- ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
||||
@@ -89,11 +109,19 @@ def run_generate():
|
||||
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results")
|
||||
parser.add_argument(
|
||||
"--info",
|
||||
nargs="?",
|
||||
type=str,
|
||||
const=datetime_now(),
|
||||
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
|
||||
)
|
||||
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
||||
args, rest = parser.parse_known_args()
|
||||
parsed = parse_numeric_cl_kwargs(rest)
|
||||
if parsed:
|
||||
print(f"parsed the following generate kwargs: {parsed}")
|
||||
parsed_args = parse_numeric_n_bool_cl_kwargs(rest)
|
||||
if parsed_args and verbose:
|
||||
print(f"parsed the following generate kwargs: {parsed_args}")
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
if args.n_obs > 0:
|
||||
examples = examples[: args.n_obs]
|
||||
@@ -109,23 +137,35 @@ def run_generate():
|
||||
fp16=args.fp16,
|
||||
task=args.task,
|
||||
prefix=args.prefix,
|
||||
**parsed,
|
||||
**parsed_args,
|
||||
)
|
||||
|
||||
if args.reference_path is None:
|
||||
return
|
||||
return {}
|
||||
|
||||
# Compute scores
|
||||
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
|
||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||
scores: dict = score_fn(output_lns, reference_lns)
|
||||
scores.update(runtime_metrics)
|
||||
print(scores)
|
||||
|
||||
if args.dump_args:
|
||||
scores.update(parsed_args)
|
||||
if args.info:
|
||||
scores["info"] = args.info
|
||||
|
||||
if verbose:
|
||||
print(*scores)
|
||||
|
||||
if args.score_path is not None:
|
||||
json.dump(scores, open(args.score_path, "w"))
|
||||
path = args.score_path
|
||||
json.dump(scores, open(path, "w"))
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage for MT:
|
||||
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
|
||||
run_generate()
|
||||
run_generate(verbose=True)
|
||||
|
||||
139
examples/seq2seq/run_eval_search.py
Normal file
139
examples/seq2seq/run_eval_search.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
try:
|
||||
from .run_eval import datetime_now, run_generate
|
||||
except ImportError:
|
||||
from run_eval import datetime_now, run_generate
|
||||
|
||||
|
||||
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
|
||||
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
|
||||
task_score_names = {
|
||||
"translation": ["bleu"],
|
||||
"translation_en_to_de": ["bleu"],
|
||||
"summarization": ["rouge1", "rouge2", "rougeL"],
|
||||
}
|
||||
|
||||
|
||||
def parse_search_arg(search):
|
||||
groups = search.split()
|
||||
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
|
||||
entry_names = list(entries.keys())
|
||||
sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
|
||||
matrix = [list(x) for x in itertools.product(*sets)]
|
||||
return matrix, entry_names
|
||||
|
||||
|
||||
def run_search():
|
||||
"""
|
||||
Run parametric search over the desired hparam space with help of ``run_eval.py``.
|
||||
|
||||
All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.
|
||||
|
||||
The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
|
||||
```
|
||||
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
|
||||
```
|
||||
which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with:
|
||||
|
||||
```
|
||||
--num_beams 5 --length_penalty 0.8 --early_stopping true
|
||||
--num_beams 5 --length_penalty 0.8 --early_stopping false
|
||||
[...]
|
||||
--num_beams 10 --length_penalty 1.2 --early_stopping false
|
||||
```
|
||||
|
||||
On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
|
||||
|
||||
|
||||
"""
|
||||
prog = sys.argv[0]
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search",
|
||||
type=str,
|
||||
required=False,
|
||||
help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
|
||||
)
|
||||
parser.add_argument(
|
||||
"--info",
|
||||
nargs="?",
|
||||
type=str,
|
||||
const=datetime_now(),
|
||||
help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
|
||||
)
|
||||
args, args_main = parser.parse_known_args()
|
||||
# we share some of the args
|
||||
args_main.extend(["--task", args.task])
|
||||
args_normal = [prog] + args_main
|
||||
|
||||
matrix, col_names = parse_search_arg(args.search)
|
||||
col_names[0:0] = task_score_names[args.task] # score cols first
|
||||
col_widths = {col: len(str(col)) for col in col_names}
|
||||
results = []
|
||||
for r in matrix:
|
||||
hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
|
||||
args_exp = " ".join(r).split()
|
||||
args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM
|
||||
sys.argv = args_normal + args_exp
|
||||
|
||||
# XXX: need to trap CUDA OOM and lower args.bs if that happens and retry
|
||||
|
||||
scores = run_generate(verbose=False)
|
||||
# make sure scores are first in the table
|
||||
result = OrderedDict()
|
||||
for score in task_score_names[args.task]:
|
||||
result[score] = scores[score]
|
||||
result.update(hparams)
|
||||
results.append(result)
|
||||
|
||||
# find widest entries
|
||||
for k, v in result.items():
|
||||
l = len(str(v))
|
||||
if l > col_widths[k]:
|
||||
col_widths[k] = l
|
||||
|
||||
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
|
||||
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
|
||||
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
|
||||
for row in results_sorted:
|
||||
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
|
||||
|
||||
best = results_sorted[0]
|
||||
for score in task_score_names[args.task]:
|
||||
del best[score]
|
||||
best_args = [f"--{k} {v}" for k, v in best.items()]
|
||||
dyn_args = ["--bs", str(args.bs)]
|
||||
if args.info:
|
||||
print(f"\nInfo: {args.info}")
|
||||
print("\nBest score args:")
|
||||
print(" ".join(args_main + best_args + dyn_args))
|
||||
|
||||
return results_sorted
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Usage:
|
||||
# [normal-run_eval_search.py cmd plus] \
|
||||
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
|
||||
#
|
||||
# Example:
|
||||
# PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
|
||||
# $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
|
||||
# --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
|
||||
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
|
||||
run_search()
|
||||
@@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import SummarizationModule, main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .run_eval_search import run_search
|
||||
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
|
||||
@@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
@@ -311,6 +312,58 @@ def test_run_eval(model):
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
|
||||
def test_run_eval_search(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
|
||||
text = {
|
||||
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
||||
"de": [
|
||||
"Maschinelles Lernen ist großartig, oder?",
|
||||
"Ich esse gerne Bananen",
|
||||
"Morgen ist wieder ein toller Tag!",
|
||||
],
|
||||
}
|
||||
|
||||
tmp_dir = Path(tempfile.mkdtemp())
|
||||
score_path = str(tmp_dir / "scores.json")
|
||||
reference_path = str(tmp_dir / "val.target")
|
||||
_dump_articles(input_file_name, text["en"])
|
||||
_dump_articles(reference_path, text["de"])
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval_search.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--reference_path",
|
||||
reference_path,
|
||||
"--task",
|
||||
task,
|
||||
"--search",
|
||||
"num_beams=1:2 length_penalty=0.9:1.0",
|
||||
]
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureStdout() as cs:
|
||||
run_search()
|
||||
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
|
||||
un_expected_strings = ["Info"]
|
||||
if "translation" in task:
|
||||
expected_strings.append("bleu")
|
||||
else:
|
||||
expected_strings.extend(["rouge1", "rouge2", "rougeL"])
|
||||
for w in expected_strings:
|
||||
assert w in cs.out
|
||||
for w in un_expected_strings:
|
||||
assert w not in cs.out
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"],
|
||||
|
||||
@@ -377,14 +377,22 @@ def assert_not_all_frozen(model):
|
||||
# CLI Parsing utils
|
||||
|
||||
|
||||
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
|
||||
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||
"""
|
||||
Parse an argv list of unspecified command line args to a dict.
|
||||
Assumes all values are either numeric or boolean in the form of true/false.
|
||||
"""
|
||||
result = {}
|
||||
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
||||
num_pairs = len(unparsed_args) // 2
|
||||
for pair_num in range(num_pairs):
|
||||
i = 2 * pair_num
|
||||
assert unparsed_args[i].startswith("--")
|
||||
if unparsed_args[i + 1].lower() == "true":
|
||||
value = True
|
||||
elif unparsed_args[i + 1].lower() == "false":
|
||||
value = False
|
||||
else:
|
||||
try:
|
||||
value = int(unparsed_args[i + 1])
|
||||
except ValueError:
|
||||
|
||||
Reference in New Issue
Block a user