Files
HuggingFace_transformer/examples/seq2seq/aggregate_distributed_results.py
2020-09-13 17:28:18 -04:00

47 lines
1.6 KiB
Python

from pathlib import Path
import fire
try:
from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Merge the per-rank result files in ``result_dir`` and score the combined predictions.

    Every ``rank*.json`` file directly inside ``result_dir`` is loaded with ``load_json`` and
    its records are concatenated. Each record is indexed with ``"pred"`` and ``"label"``;
    a ``"source"`` key is optional and is only checked on the first record.

    Args:
        result_dir: Directory containing the ``rank*.json`` partial results.
        save_dir: Directory to write outputs to. It is created (including missing parents)
            if needed. Defaults to ``result_dir`` when omitted.
        save_prefix: Stem for the text files written when ``just_metrics`` is false.
            Defaults to ``"generated"``.
        calc_bleu: Score with ``calculate_bleu`` instead of ``calculate_rouge``.
        just_metrics: If true, stop after writing ``metrics.json`` and printing the metrics.

    Returns:
        None. Side effects: writes ``metrics.json`` and, unless ``just_metrics`` is set,
        ``{save_prefix}.target``, ``{save_prefix}.pred_target`` and (when the records carry
        a ``"source"`` field) ``{save_prefix}.source`` into ``save_dir``.
    """
    src_dir = Path(result_dir)
    # Path(None) raises TypeError, so an omitted save_dir falls back to the input directory.
    save_dir = src_dir if save_dir is None else Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # glob() yields entries in filesystem-dependent order; sorting (lexicographically by path)
    # keeps the combined output reproducible across runs and machines.
    paths_to_combine = sorted(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]

    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    # TODO: a better file name would be {prefix}_{rouge|bleu}.json
    save_json(metrics, save_dir.joinpath("metrics.json"))
    print(metrics)
    if just_metrics:
        return

    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")

    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    # Guard against an empty record list so records[0] cannot raise IndexError.
    if records and "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
if __name__ == "__main__":
    # Script entry point: hand combine_partial_results to fire.Fire so its parameters
    # can be supplied from the command line.
    fire.Fire(combine_partial_results)