[s2s] tiny QOL improvement: run_eval prints scores (#6341)
This commit is contained in:
@@ -33,7 +33,6 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
||||
new_src, new_tgt = src, tgt
|
||||
else: # can fit, keep adding
|
||||
new_src, new_tgt = cand_src, cand_tgt
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
# cleanup
|
||||
if new_src:
|
||||
|
||||
@@ -89,7 +89,7 @@ def run_generate():
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
if args.n_obs > 0:
|
||||
examples = examples[: args.n_obs]
|
||||
|
||||
Path(args.save_path).parent.mkdir(exist_ok=True)
|
||||
generate_summaries_or_translations(
|
||||
examples,
|
||||
args.save_path,
|
||||
@@ -107,6 +107,7 @@ def run_generate():
|
||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||
scores: dict = score_fn(output_lns, reference_lns)
|
||||
print(scores)
|
||||
if args.score_path is not None:
|
||||
json.dump(scores, open(args.score_path, "w+"))
|
||||
return scores
|
||||
|
||||
Reference in New Issue
Block a user