[s2s] run_eval/run_eval_search tweaks (#7192)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-09-17 11:26:38 -07:00
committed by GitHub
parent 9c5bcab5b0
commit efeab6a3f1
2 changed files with 58 additions and 48 deletions

View File

@@ -15,7 +15,6 @@ except ImportError:
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
"translation": ["bleu"],
"translation_en_to_de": ["bleu"],
"summarization": ["rouge1", "rouge2", "rougeL"],
}
@@ -66,9 +65,7 @@ def run_search():
parser.add_argument(
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
)
parser.add_argument(
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
)
parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
parser.add_argument(
"--info",
nargs="?",
@@ -81,8 +78,11 @@ def run_search():
args_main.extend(["--task", args.task])
args_normal = [prog] + args_main
# to support variations like translation_en_to_de"
task = "translation" if "translation" in args.task else "summarization"
matrix, col_names = parse_search_arg(args.search)
col_names[0:0] = task_score_names[args.task] # score cols first
col_names[0:0] = task_score_names[task] # score cols first
col_widths = {col: len(str(col)) for col in col_names}
results = []
for r in matrix:
@@ -96,7 +96,7 @@ def run_search():
scores = run_generate(verbose=False)
# make sure scores are first in the table
result = OrderedDict()
for score in task_score_names[args.task]:
for score in task_score_names[task]:
result[score] = scores[score]
result.update(hparams)
results.append(result)
@@ -107,14 +107,14 @@ def run_search():
if l > col_widths[k]:
col_widths[k] = l
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
for row in results_sorted:
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
best = results_sorted[0]
for score in task_score_names[args.task]:
for score in task_score_names[task]:
del best[score]
best_args = [f"--{k} {v}" for k, v in best.items()]
dyn_args = ["--bs", str(args.bs)]