[s2s] run_eval/run_eval_search tweaks (#7192)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -15,7 +15,6 @@ except ImportError:
|
||||
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
|
||||
task_score_names = {
|
||||
"translation": ["bleu"],
|
||||
"translation_en_to_de": ["bleu"],
|
||||
"summarization": ["rouge1", "rouge2", "rougeL"],
|
||||
}
|
||||
|
||||
@@ -66,9 +65,7 @@ def run_search():
|
||||
parser.add_argument(
|
||||
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
|
||||
)
|
||||
parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
|
||||
parser.add_argument(
|
||||
"--info",
|
||||
nargs="?",
|
||||
@@ -81,8 +78,11 @@ def run_search():
|
||||
args_main.extend(["--task", args.task])
|
||||
args_normal = [prog] + args_main
|
||||
|
||||
# to support variations like "translation_en_to_de"
|
||||
task = "translation" if "translation" in args.task else "summarization"
|
||||
|
||||
matrix, col_names = parse_search_arg(args.search)
|
||||
col_names[0:0] = task_score_names[args.task] # score cols first
|
||||
col_names[0:0] = task_score_names[task] # score cols first
|
||||
col_widths = {col: len(str(col)) for col in col_names}
|
||||
results = []
|
||||
for r in matrix:
|
||||
@@ -96,7 +96,7 @@ def run_search():
|
||||
scores = run_generate(verbose=False)
|
||||
# make sure scores are first in the table
|
||||
result = OrderedDict()
|
||||
for score in task_score_names[args.task]:
|
||||
for score in task_score_names[task]:
|
||||
result[score] = scores[score]
|
||||
result.update(hparams)
|
||||
results.append(result)
|
||||
@@ -107,14 +107,14 @@ def run_search():
|
||||
if l > col_widths[k]:
|
||||
col_widths[k] = l
|
||||
|
||||
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
|
||||
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
|
||||
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
|
||||
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
|
||||
for row in results_sorted:
|
||||
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
|
||||
|
||||
best = results_sorted[0]
|
||||
for score in task_score_names[args.task]:
|
||||
for score in task_score_names[task]:
|
||||
del best[score]
|
||||
best_args = [f"--{k} {v}" for k, v in best.items()]
|
||||
dyn_args = ["--bs", str(args.bs)]
|
||||
|
||||
Reference in New Issue
Block a user