47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from pathlib import Path
|
|
|
|
import fire
|
|
|
|
|
|
try:
|
|
from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
|
except ImportError:
|
|
from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
|
|
|
|
|
|
def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Merge per-rank ``rank*.json`` result files, score them, and write the combined outputs.

    Every ``rank*.json`` file directly inside ``result_dir`` is read with ``load_json``
    (each file is expected to hold a list of records) and the records are concatenated.
    Each record must provide ``"pred"`` and ``"label"`` keys; ``"source"`` is optional and
    is detected from the first record only.

    Files written to ``save_dir``:
        * ``metrics.json`` - output of the scoring function (always written).
        * ``{save_prefix}.target`` - the labels.
        * ``{save_prefix}.pred_target`` - the predictions.
        * ``{save_prefix}.source`` - the sources, only if the first record has a ``"source"`` key.

    Args:
        result_dir: Directory containing the ``rank*.json`` partial results.
        save_dir: Directory to write outputs to; created (with parents) if missing.
            Defaults to ``result_dir`` when omitted.
        save_prefix: Filename stem for the text outputs. Defaults to ``"generated"``.
        calc_bleu: Score with ``calculate_bleu`` instead of ``calculate_rouge``.
        just_metrics: If true, stop after writing and printing the metrics.

    Raises:
        ValueError: If no records could be loaded from ``result_dir``.
    """
    src_dir = Path(result_dir)

    # glob() yields entries in filesystem-dependent order; sort so the combined files are
    # reproducible between runs. Note the sort is lexicographic (rank10 comes before rank2).
    paths_to_combine = sorted(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    if not records:
        # Fail with a clear message instead of scoring nothing and hitting an IndexError later.
        raise ValueError(f"No records found in rank*.json files under {src_dir}")

    # Path(None) raises TypeError, so the declared default needs an explicit fallback.
    save_dir = src_dir if save_dir is None else Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return

    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")

    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: python-fire turns the function's parameters into command-line arguments.
    fire.Fire(combine_partial_results)
|