From ade3cdf5adfcff7736b326b1360fcf2b59aae47e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 6 Dec 2019 11:36:44 +0100 Subject: [PATCH] integrate ROUGE --- examples/summarization/modeling_bertabs.py | 65 +--------------- examples/summarization/run_summarization.py | 85 +++++++++++++++++++-- requirements.txt | 1 + 3 files changed, 82 insertions(+), 69 deletions(-) diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py index efca33fb56..57126a4df3 100644 --- a/examples/summarization/modeling_bertabs.py +++ b/examples/summarization/modeling_bertabs.py @@ -21,9 +21,6 @@ # SOFTWARE. import copy import math -import shutil -import time -import os import numpy as np import torch @@ -1082,11 +1079,6 @@ class Translator(object): return translations - def _report_rouge(self, gold_path, can_path): - self.logger.info("Calculating Rouge") - results_dict = test_rouge(self.args.temp_dir, can_path, gold_path) - return results_dict - def tile(x, count, dim=0): """ @@ -1113,63 +1105,10 @@ def tile(x, count, dim=0): # -# All things ROUGE. Uses `pyrouge` which is a hot mess. +# Optimizer for training. We keep this here in case we want to add +# a finetuning script. # - -def test_rouge(temp_dir, cand, ref): - candidates = [line.strip() for line in open(cand, encoding="utf-8")] - references = [line.strip() for line in open(ref, encoding="utf-8")] - print(len(candidates)) - print(len(references)) - assert len(candidates) == len(references) - - cnt = len(candidates) - current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) - if not os.path.isdir(tmp_dir): - os.mkdir(tmp_dir) - os.mkdir(tmp_dir + "/candidate") - os.mkdir(tmp_dir + "/reference") - try: - - for i in range(cnt): - if len(references[i]) < 1: - continue - with open( - tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8" - ) as f: - f.write(candidates[i]) - with open( - tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8" - ) as f: - f.write(references[i]) - r = pyrouge.Rouge155(temp_dir=temp_dir) - r.model_dir = tmp_dir + "/reference/" - r.system_dir = tmp_dir + "/candidate/" - r.model_filename_pattern = "ref.#ID#.txt" - r.system_filename_pattern = r"cand.(\d+).txt" - rouge_results = r.convert_and_evaluate() - print(rouge_results) - results_dict = r.output_to_dict(rouge_results) - finally: - pass - if os.path.isdir(tmp_dir): - shutil.rmtree(tmp_dir) - return results_dict - - -def rouge_results_to_str(results_dict): - return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( - results_dict["rouge_1_f_score"] * 100, - results_dict["rouge_2_f_score"] * 100, - results_dict["rouge_l_f_score"] * 100, - results_dict["rouge_1_recall"] * 100, - results_dict["rouge_2_recall"] * 100, - results_dict["rouge_l_recall"] * 100, - ) - - class BertSumOptimizer(object): """ Specific optimizer for BertSum. diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index ed663e880b..a9d08aca82 100644 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -41,6 +41,26 @@ def evaluate(args): "PAD": tokenizer.vocab["[PAD]"], } + if args.compute_rouge: + reference_summaries = [] + generated_summaries = [] + + import rouge + import nltk + nltk.download('punkt') + rouge_evaluator = rouge.Rouge( + metrics=['rouge-n', 'rouge-l'], + max_n=2, + limit_length=True, + length_limit=args.beam_size, + length_limit_type='words', + apply_avg=True, + apply_best=False, + alpha=0.5, # Default F1_score + weight_factor=1.2, + stemming=True, + ) + # these (unused) arguments are defined to keep the compatibility # with the legacy code and will be deleted in a next iteration. args.result_path = "" @@ -66,6 +86,16 @@ def evaluate(args): summaries = [format_summary(t) for t in translations] save_summaries(summaries, args.summaries_output_dir, batch.document_names) + if args.compute_rouge: + reference_summaries += batch.tgt_str + generated_summaries += summaries + + if args.compute_rouge: + scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries) + str_scores = format_rouge_scores(scores) + save_rouge_scores(str_scores) + print(str_scores) + def format_summary(translation): """ Transforms the output of the `from_batch` function @@ -86,6 +116,41 @@ def format_summary(translation): return summary +def format_rouge_scores(scores): + return """\n +****** ROUGE SCORES ****** + +** ROUGE 1 +F1 >> {:.3f} +Precision >> {:.3f} +Recall >> {:.3f} + +** ROUGE 2 +F1 >> {:.3f} +Precision >> {:.3f} +Recall >> {:.3f} + +** ROUGE L +F1 >> {:.3f} +Precision >> {:.3f} +Recall >> {:.3f}""".format( + scores['rouge-1']['f'], + scores['rouge-1']['p'], + scores['rouge-1']['r'], + scores['rouge-2']['f'], + scores['rouge-2']['p'], + scores['rouge-2']['r'], + scores['rouge-l']['f'], + scores['rouge-l']['p'], + scores['rouge-l']['r'], + ) + + +def save_rouge_scores(str_scores): + with open("rouge_scores.txt", "w") as output: + output.write(str_scores) + + def save_summaries(summaries, path, original_document_name): """ Write the summaries in fies that are prefixed by the original files' name with the `_summary` appended. @@ -142,26 +207,27 @@ def collate(data, tokenizer, block_size): """ data = [x for x in data if not len(x[1]) == 0] # remove empty_files names = [name for name, _, _ in data] + summaries = [" ".join(summary_list) for _, _, summary_list in data] encoded_text = [ encode_for_summarization(story, summary, tokenizer) for _, story, summary in data ] - stories = torch.tensor( + encoded_stories = torch.tensor( [ fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text ] ) - encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id) - encoder_mask = build_mask(stories, tokenizer.pad_token_id) + encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id) + encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id) batch = Batch( document_names=names, - batch_size=len(stories), - src=stories, + batch_size=len(encoded_stories), + src=encoded_stories, segs=encoder_token_type_ids, mask_src=encoder_mask, - tgt_str=[""] * len(stories), + tgt_str=summaries, ) return batch @@ -196,6 +262,13 @@ def main(): required=False, help="The folder in wich the summaries should be written. Defaults to the folder where the documents are", ) + parser.add_argument( + "--compute_rouge", + default=False, + type=bool, + required=False, + help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.", + ) # EVALUATION options parser.add_argument( "--visible_gpus", diff --git a/requirements.txt b/requirements.txt index 236ac1c430..2cbcc3809d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ sentencepiece # For XLM sacremoses # For ROUGE +nltk py-rouge