integrate ROUGE
This commit is contained in:
committed by
Julien Chaumond
parent
076602bdc4
commit
ade3cdf5ad
@@ -21,9 +21,6 @@
|
||||
# SOFTWARE.
|
||||
import copy
|
||||
import math
|
||||
import shutil
|
||||
import time
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -1082,11 +1079,6 @@ class Translator(object):
|
||||
|
||||
return translations
|
||||
|
||||
def _report_rouge(self, gold_path, can_path):
    """Compute ROUGE scores for a candidate file against a gold file.

    Args:
        gold_path: path to the reference (gold) summaries.
        can_path: path to the candidate summaries.

    Returns:
        Whatever `test_rouge` returns for the two files, using
        `self.args.temp_dir` as its scratch directory.
    """
    self.logger.info("Calculating Rouge")
    # `test_rouge` takes the candidate path first, then the reference path.
    return test_rouge(self.args.temp_dir, can_path, gold_path)
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
"""
|
||||
@@ -1113,63 +1105,10 @@ def tile(x, count, dim=0):
|
||||
|
||||
|
||||
#
|
||||
# All things ROUGE. Uses `pyrouge` which is a hot mess.
|
||||
# Optimizer for training. We keep this here in case we want to add
|
||||
# a finetuning script.
|
||||
#
|
||||
|
||||
|
||||
def test_rouge(temp_dir, cand, ref):
    """Score candidate summaries against references with `pyrouge`.

    Both files are read as UTF-8 with one summary per line; line ``i`` of
    `cand` is paired with line ``i`` of `ref`. Every pair whose reference
    is non-empty is written to its own file inside a scratch directory
    created under `temp_dir`. The scratch directory is removed before
    returning, including when scoring raises.

    Args:
        temp_dir: existing directory in which the scratch directory is
            created; it is also handed to `pyrouge.Rouge155` as its
            `temp_dir`.
        cand: path to the candidate summaries file.
        ref: path to the reference summaries file.

    Returns:
        The dictionary produced by `Rouge155.output_to_dict` from the
        evaluation output.

    Raises:
        AssertionError: if the two files do not contain the same number
            of lines.
    """
    import tempfile  # local import: only this function needs it

    # Read through context managers so the file handles are closed promptly.
    with open(cand, encoding="utf-8") as cand_file:
        candidates = [line.strip() for line in cand_file]
    with open(ref, encoding="utf-8") as ref_file:
        references = [line.strip() for line in ref_file]
    print(len(candidates))
    print(len(references))
    assert len(candidates) == len(references)

    current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # The timestamp alone only has one-second resolution; `mkdtemp` appends
    # a unique suffix so runs started in the same second never share (and
    # never delete) each other's scratch files.
    tmp_dir = tempfile.mkdtemp(
        prefix="rouge-tmp-{}-".format(current_time), dir=temp_dir
    )
    try:
        candidate_dir = os.path.join(tmp_dir, "candidate")
        reference_dir = os.path.join(tmp_dir, "reference")
        os.mkdir(candidate_dir)
        os.mkdir(reference_dir)

        for i, (candidate, reference) in enumerate(zip(candidates, references)):
            # Pairs with an empty reference are left out of the evaluation.
            if not reference:
                continue
            cand_file_path = os.path.join(candidate_dir, "cand.{}.txt".format(i))
            with open(cand_file_path, "w", encoding="utf-8") as f:
                f.write(candidate)
            ref_file_path = os.path.join(reference_dir, "ref.{}.txt".format(i))
            with open(ref_file_path, "w", encoding="utf-8") as f:
                f.write(reference)

        r = pyrouge.Rouge155(temp_dir=temp_dir)
        # Joining with "" keeps the trailing separator on both directories.
        r.model_dir = os.path.join(reference_dir, "")
        r.system_dir = os.path.join(candidate_dir, "")
        r.model_filename_pattern = "ref.#ID#.txt"
        r.system_filename_pattern = r"cand.(\d+).txt"
        rouge_results = r.convert_and_evaluate()
        print(rouge_results)
        results_dict = r.output_to_dict(rouge_results)
    finally:
        # Always clean up the scratch files, even when scoring failed.
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
    return results_dict
|
||||
|
||||
|
||||
def rouge_results_to_str(results_dict):
    """Format ROUGE F-scores and recalls as a two-line summary string.

    Args:
        results_dict: mapping holding the keys ``rouge_1_f_score``,
            ``rouge_2_f_score``, ``rouge_l_f_score``, ``rouge_1_recall``,
            ``rouge_2_recall`` and ``rouge_l_recall``; each value is
            multiplied by 100 before being formatted.

    Returns:
        A string with one line of F-scores and one line of recalls, each
        value printed with two decimals, ending with a newline.

    Raises:
        KeyError: if one of the keys listed above is missing.
    """
    # Only ROUGE-1, ROUGE-2 and ROUGE-L are reported, so the labels read
    # "1/2/l" to match the three values printed on each line.
    template = (
        ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\n"
        "ROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\n"
    )
    score_keys = (
        "rouge_1_f_score",
        "rouge_2_f_score",
        "rouge_l_f_score",
        "rouge_1_recall",
        "rouge_2_recall",
        "rouge_l_recall",
    )
    return template.format(*(results_dict[key] * 100 for key in score_keys))
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user