integrate ROUGE
This commit is contained in:
committed by
Julien Chaumond
parent
076602bdc4
commit
ade3cdf5ad
@@ -21,9 +21,6 @@
|
|||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -1082,11 +1079,6 @@ class Translator(object):
|
|||||||
|
|
||||||
return translations
|
return translations
|
||||||
|
|
||||||
def _report_rouge(self, gold_path, can_path):
|
|
||||||
self.logger.info("Calculating Rouge")
|
|
||||||
results_dict = test_rouge(self.args.temp_dir, can_path, gold_path)
|
|
||||||
return results_dict
|
|
||||||
|
|
||||||
|
|
||||||
def tile(x, count, dim=0):
|
def tile(x, count, dim=0):
|
||||||
"""
|
"""
|
||||||
@@ -1113,63 +1105,10 @@ def tile(x, count, dim=0):
|
|||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# All things ROUGE. Uses `pyrouge` which is a hot mess.
|
# Optimizer for training. We keep this here in case we want to add
|
||||||
|
# a finetuning script.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def test_rouge(temp_dir, cand, ref):
|
|
||||||
candidates = [line.strip() for line in open(cand, encoding="utf-8")]
|
|
||||||
references = [line.strip() for line in open(ref, encoding="utf-8")]
|
|
||||||
print(len(candidates))
|
|
||||||
print(len(references))
|
|
||||||
assert len(candidates) == len(references)
|
|
||||||
|
|
||||||
cnt = len(candidates)
|
|
||||||
current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
|
||||||
tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time))
|
|
||||||
if not os.path.isdir(tmp_dir):
|
|
||||||
os.mkdir(tmp_dir)
|
|
||||||
os.mkdir(tmp_dir + "/candidate")
|
|
||||||
os.mkdir(tmp_dir + "/reference")
|
|
||||||
try:
|
|
||||||
|
|
||||||
for i in range(cnt):
|
|
||||||
if len(references[i]) < 1:
|
|
||||||
continue
|
|
||||||
with open(
|
|
||||||
tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
f.write(candidates[i])
|
|
||||||
with open(
|
|
||||||
tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
f.write(references[i])
|
|
||||||
r = pyrouge.Rouge155(temp_dir=temp_dir)
|
|
||||||
r.model_dir = tmp_dir + "/reference/"
|
|
||||||
r.system_dir = tmp_dir + "/candidate/"
|
|
||||||
r.model_filename_pattern = "ref.#ID#.txt"
|
|
||||||
r.system_filename_pattern = r"cand.(\d+).txt"
|
|
||||||
rouge_results = r.convert_and_evaluate()
|
|
||||||
print(rouge_results)
|
|
||||||
results_dict = r.output_to_dict(rouge_results)
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
if os.path.isdir(tmp_dir):
|
|
||||||
shutil.rmtree(tmp_dir)
|
|
||||||
return results_dict
|
|
||||||
|
|
||||||
|
|
||||||
def rouge_results_to_str(results_dict):
|
|
||||||
return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format(
|
|
||||||
results_dict["rouge_1_f_score"] * 100,
|
|
||||||
results_dict["rouge_2_f_score"] * 100,
|
|
||||||
results_dict["rouge_l_f_score"] * 100,
|
|
||||||
results_dict["rouge_1_recall"] * 100,
|
|
||||||
results_dict["rouge_2_recall"] * 100,
|
|
||||||
results_dict["rouge_l_recall"] * 100,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BertSumOptimizer(object):
|
class BertSumOptimizer(object):
|
||||||
""" Specific optimizer for BertSum.
|
""" Specific optimizer for BertSum.
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,26 @@ def evaluate(args):
|
|||||||
"PAD": tokenizer.vocab["[PAD]"],
|
"PAD": tokenizer.vocab["[PAD]"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.compute_rouge:
|
||||||
|
reference_summaries = []
|
||||||
|
generated_summaries = []
|
||||||
|
|
||||||
|
import rouge
|
||||||
|
import nltk
|
||||||
|
nltk.download('punkt')
|
||||||
|
rouge_evaluator = rouge.Rouge(
|
||||||
|
metrics=['rouge-n', 'rouge-l'],
|
||||||
|
max_n=2,
|
||||||
|
limit_length=True,
|
||||||
|
length_limit=args.beam_size,
|
||||||
|
length_limit_type='words',
|
||||||
|
apply_avg=True,
|
||||||
|
apply_best=False,
|
||||||
|
alpha=0.5, # Default F1_score
|
||||||
|
weight_factor=1.2,
|
||||||
|
stemming=True,
|
||||||
|
)
|
||||||
|
|
||||||
# these (unused) arguments are defined to keep the compatibility
|
# these (unused) arguments are defined to keep the compatibility
|
||||||
# with the legacy code and will be deleted in a next iteration.
|
# with the legacy code and will be deleted in a next iteration.
|
||||||
args.result_path = ""
|
args.result_path = ""
|
||||||
@@ -66,6 +86,16 @@ def evaluate(args):
|
|||||||
summaries = [format_summary(t) for t in translations]
|
summaries = [format_summary(t) for t in translations]
|
||||||
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||||
|
|
||||||
|
if args.compute_rouge:
|
||||||
|
reference_summaries += batch.tgt_str
|
||||||
|
generated_summaries += summaries
|
||||||
|
|
||||||
|
if args.compute_rouge:
|
||||||
|
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
|
||||||
|
str_scores = format_rouge_scores(scores)
|
||||||
|
save_rouge_scores(str_scores)
|
||||||
|
print(str_scores)
|
||||||
|
|
||||||
|
|
||||||
def format_summary(translation):
|
def format_summary(translation):
|
||||||
""" Transforms the output of the `from_batch` function
|
""" Transforms the output of the `from_batch` function
|
||||||
@@ -86,6 +116,41 @@ def format_summary(translation):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def format_rouge_scores(scores):
|
||||||
|
return """\n
|
||||||
|
****** ROUGE SCORES ******
|
||||||
|
|
||||||
|
** ROUGE 1
|
||||||
|
F1 >> {:.3f}
|
||||||
|
Precision >> {:.3f}
|
||||||
|
Recall >> {:.3f}
|
||||||
|
|
||||||
|
** ROUGE 2
|
||||||
|
F1 >> {:.3f}
|
||||||
|
Precision >> {:.3f}
|
||||||
|
Recall >> {:.3f}
|
||||||
|
|
||||||
|
** ROUGE L
|
||||||
|
F1 >> {:.3f}
|
||||||
|
Precision >> {:.3f}
|
||||||
|
Recall >> {:.3f}""".format(
|
||||||
|
scores['rouge-1']['f'],
|
||||||
|
scores['rouge-1']['p'],
|
||||||
|
scores['rouge-1']['r'],
|
||||||
|
scores['rouge-2']['f'],
|
||||||
|
scores['rouge-2']['p'],
|
||||||
|
scores['rouge-2']['r'],
|
||||||
|
scores['rouge-l']['f'],
|
||||||
|
scores['rouge-l']['p'],
|
||||||
|
scores['rouge-l']['r'],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_rouge_scores(str_scores):
|
||||||
|
with open("rouge_scores.txt", "w") as output:
|
||||||
|
output.write(str_scores)
|
||||||
|
|
||||||
|
|
||||||
def save_summaries(summaries, path, original_document_name):
|
def save_summaries(summaries, path, original_document_name):
|
||||||
""" Write the summaries in fies that are prefixed by the original
|
""" Write the summaries in fies that are prefixed by the original
|
||||||
files' name with the `_summary` appended.
|
files' name with the `_summary` appended.
|
||||||
@@ -142,26 +207,27 @@ def collate(data, tokenizer, block_size):
|
|||||||
"""
|
"""
|
||||||
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
||||||
names = [name for name, _, _ in data]
|
names = [name for name, _, _ in data]
|
||||||
|
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||||
|
|
||||||
encoded_text = [
|
encoded_text = [
|
||||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
||||||
]
|
]
|
||||||
stories = torch.tensor(
|
encoded_stories = torch.tensor(
|
||||||
[
|
[
|
||||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
||||||
for story, _ in encoded_text
|
for story, _ in encoded_text
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
|
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||||
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
|
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||||
|
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
document_names=names,
|
document_names=names,
|
||||||
batch_size=len(stories),
|
batch_size=len(encoded_stories),
|
||||||
src=stories,
|
src=encoded_stories,
|
||||||
segs=encoder_token_type_ids,
|
segs=encoder_token_type_ids,
|
||||||
mask_src=encoder_mask,
|
mask_src=encoder_mask,
|
||||||
tgt_str=[""] * len(stories),
|
tgt_str=summaries,
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
@@ -196,6 +262,13 @@ def main():
|
|||||||
required=False,
|
required=False,
|
||||||
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--compute_rouge",
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
|
required=False,
|
||||||
|
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||||
|
)
|
||||||
# EVALUATION options
|
# EVALUATION options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--visible_gpus",
|
"--visible_gpus",
|
||||||
|
|||||||
@@ -11,4 +11,5 @@ sentencepiece
|
|||||||
# For XLM
|
# For XLM
|
||||||
sacremoses
|
sacremoses
|
||||||
# For ROUGE
|
# For ROUGE
|
||||||
|
nltk
|
||||||
py-rouge
|
py-rouge
|
||||||
|
|||||||
Reference in New Issue
Block a user