[s2s] tiny QOL improvement: run_eval prints scores (#6341)
This commit is contained in:
@@ -33,7 +33,6 @@ def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
|||||||
new_src, new_tgt = src, tgt
|
new_src, new_tgt = src, tgt
|
||||||
else: # can fit, keep adding
|
else: # can fit, keep adding
|
||||||
new_src, new_tgt = cand_src, cand_tgt
|
new_src, new_tgt = cand_src, cand_tgt
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
|
|
||||||
# cleanup
|
# cleanup
|
||||||
if new_src:
|
if new_src:
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ def run_generate():
|
|||||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||||
if args.n_obs > 0:
|
if args.n_obs > 0:
|
||||||
examples = examples[: args.n_obs]
|
examples = examples[: args.n_obs]
|
||||||
|
Path(args.save_path).parent.mkdir(exist_ok=True)
|
||||||
generate_summaries_or_translations(
|
generate_summaries_or_translations(
|
||||||
examples,
|
examples,
|
||||||
args.save_path,
|
args.save_path,
|
||||||
@@ -107,6 +107,7 @@ def run_generate():
|
|||||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||||
scores: dict = score_fn(output_lns, reference_lns)
|
scores: dict = score_fn(output_lns, reference_lns)
|
||||||
|
print(scores)
|
||||||
if args.score_path is not None:
|
if args.score_path is not None:
|
||||||
json.dump(scores, open(args.score_path, "w+"))
|
json.dump(scores, open(args.score_path, "w+"))
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
Reference in New Issue
Block a user