From 7296fea1d689f47de69fd45e438e42d65ca5a393 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 27 Sep 2020 16:27:19 -0400 Subject: [PATCH] [s2s] rougeLSum expects \n between sentences (#7410) Co-authored-by: Swetha Mandava --- examples/requirements.txt | 1 + examples/seq2seq/rouge_cli.py | 17 +++++ examples/seq2seq/run_eval_search.py | 3 +- examples/seq2seq/sentence_splitter.py | 21 ++++++ examples/seq2seq/test_calculate_rouge.py | 80 +++++++++++++++++++++++ examples/seq2seq/test_seq2seq_examples.py | 4 +- examples/seq2seq/utils.py | 64 ++++++++++++++---- 7 files changed, 176 insertions(+), 14 deletions(-) create mode 100644 examples/seq2seq/rouge_cli.py create mode 100644 examples/seq2seq/sentence_splitter.py create mode 100644 examples/seq2seq/test_calculate_rouge.py diff --git a/examples/requirements.txt b/examples/requirements.txt index 9b4433151c..f080459723 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -11,6 +11,7 @@ git-python==1.0.3 faiss-cpu streamlit elasticsearch +nltk pandas datasets fire diff --git a/examples/seq2seq/rouge_cli.py b/examples/seq2seq/rouge_cli.py new file mode 100644 index 0000000000..b193581bc8 --- /dev/null +++ b/examples/seq2seq/rouge_cli.py @@ -0,0 +1,17 @@ +import fire + +from utils import calculate_rouge, save_json + + +def calculate_rouge_path(pred_path, tgt_path, save_path=None, **kwargs): + """Kwargs will be passed to calculate_rouge""" + pred_lns = [x.strip() for x in open(pred_path).readlines()] + tgt_lns = [x.strip() for x in open(tgt_path).readlines()][: len(pred_lns)] + metrics = calculate_rouge(pred_lns, tgt_lns, **kwargs) + if save_path is not None: + save_json(metrics, save_path) + return metrics # these print nicely + + +if __name__ == "__main__": + fire.Fire(calculate_rouge_path) diff --git a/examples/seq2seq/run_eval_search.py b/examples/seq2seq/run_eval_search.py index 292918c9f3..8052b921d3 100755 --- a/examples/seq2seq/run_eval_search.py +++ b/examples/seq2seq/run_eval_search.py @@ -7,13 +7,14 @@ import sys from collections import OrderedDict from run_eval import datetime_now, run_generate +from utils import ROUGE_KEYS # A table of supported tasks and the list of scores in the order of importance to be sorted by. # To add a new task, simply list the score names that `run_eval.run_generate()` returns task_score_names = { "translation": ["bleu"], - "summarization": ["rouge1", "rouge2", "rougeL"], + "summarization": ROUGE_KEYS, } diff --git a/examples/seq2seq/sentence_splitter.py b/examples/seq2seq/sentence_splitter.py new file mode 100644 index 0000000000..197c4b250b --- /dev/null +++ b/examples/seq2seq/sentence_splitter.py @@ -0,0 +1,21 @@ +import re + + +try: + import nltk + + NLTK_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + NLTK_AVAILABLE = False + +if NLTK_AVAILABLE: + try: + nltk.download("punkt", quiet=True) + except FileExistsError: # multiprocessing race condition + pass + + +def add_newline_to_end_of_each_sentence(x: str) -> str: + re.sub("", "", x) # remove pegasus newline char + assert NLTK_AVAILABLE, "nltk must be installed to separate newlines betwee sentences. (pip install nltk)" + return "\n".join(nltk.sent_tokenize(x)) diff --git a/examples/seq2seq/test_calculate_rouge.py b/examples/seq2seq/test_calculate_rouge.py new file mode 100644 index 0000000000..bfa35adf11 --- /dev/null +++ b/examples/seq2seq/test_calculate_rouge.py @@ -0,0 +1,80 @@ +from collections import defaultdict +from pathlib import Path + +import pandas as pd + +from rouge_cli import calculate_rouge_path +from utils import calculate_rouge + + +PRED = [ + 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.', + "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.", + "Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.", +] + +TGT = [ + 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .', + "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .", + "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .", +] + + +def test_disaggregated_scores_are_determinstic(): + no_aggregation = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"]) + assert isinstance(no_aggregation, defaultdict) + no_aggregation_just_r2 = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"]) + assert ( + pd.DataFrame(no_aggregation["rouge2"]).fmeasure.mean() + == pd.DataFrame(no_aggregation_just_r2["rouge2"]).fmeasure.mean() + ) + + +def test_newline_cnn_improvement(): + k = "rougeLsum" + score = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[k])[k] + score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[k])[k] + assert score > score_no_sep + + +def test_newline_irrelevant_for_other_metrics(): + k = ["rouge1", "rouge2", "rougeL"] + score_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=k) + score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=k) + assert score_sep == score_no_sep + + +def test_single_sent_scores_dont_depend_on_newline_sep(): + pred = [ + "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.", + 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .', + ] + tgt = [ + "Margot Frank, died in 1945, a month earlier than previously thought.", + 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.', + ] + assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False) + + +def test_pegasus_newline(): + + pred = [ + """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says . "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """ + ] + tgt = [ + """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .""" + ] + + prev_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"], newline_sep=False)["rougeLsum"] + new_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"])["rougeLsum"] + assert new_score > prev_score + + +def test_rouge_cli(): + data_dir = Path("examples/seq2seq/test_data/wmt_en_ro") + metrics = calculate_rouge_path(data_dir.joinpath("test.source"), data_dir.joinpath("test.target")) + assert isinstance(metrics, dict) + metrics_default_dict = calculate_rouge_path( + data_dir.joinpath("test.source"), data_dir.joinpath("test.target"), bootstrap_aggregation=False + ) + assert isinstance(metrics_default_dict, defaultdict) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index a6fa8174d7..3e054649bc 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -20,7 +20,7 @@ from run_eval_search import run_search from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers.hf_api import HfApi from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow -from utils import label_smoothed_nll_loss, lmap, load_json +from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json logging.basicConfig(level=logging.DEBUG) @@ -365,7 +365,7 @@ def test_run_eval_search(model): if "translation" in task: expected_strings.append("bleu") else: - expected_strings.extend(["rouge1", "rouge2", "rougeL"]) + expected_strings.extend(ROUGE_KEYS) for w in expected_strings: assert w in cs.out for w in un_expected_strings: diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index cf5d778792..ac1629c0c5 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -18,6 +18,7 @@ from sacrebleu import corpus_bleu from torch import nn from torch.utils.data import Dataset, Sampler +from sentence_splitter import add_newline_to_end_of_each_sentence from transformers import BartTokenizer from transformers.file_utils import cached_property @@ -378,19 +379,63 @@ def get_git_info(): return repo_infos -ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] +ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] -def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: - scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) +def extract_rouge_mid_statistics(dct): + new_dict = {} + for k1, v1 in dct.items(): + mid = v1.mid + new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]} + return new_dict + + +def calculate_rouge( + pred_lns: List[str], + tgt_lns: List[str], + use_stemmer=True, + rouge_keys=ROUGE_KEYS, + return_precision_and_recall=False, + bootstrap_aggregation=True, + newline_sep=True, +) -> Dict: + """Calculate rouge using rouge_scorer package. + + Args: + pred_lns: list of summaries generated by model + tgt_lns: list of groundtruth summaries (e.g. contents of val.target) + use_stemmer: Bool indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. + rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum + return_precision_and_recall: (False) whether to also return precision and recall. + bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False + this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]`` + newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL + on multi sentence summaries (CNN/DM dataset). + + Returns: + Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys + + """ + scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer) aggregator = scoring.BootstrapAggregator() - - for reference_ln, output_ln in zip(reference_lns, output_lns): - scores = scorer.score(reference_ln, output_ln) + for pred, tgt in zip(tgt_lns, pred_lns): + # rougeLsum expects "\n" separated sentences within a summary + if newline_sep: + pred = add_newline_to_end_of_each_sentence(pred) + tgt = add_newline_to_end_of_each_sentence(tgt) + scores = scorer.score(pred, tgt) aggregator.add_scores(scores) - result = aggregator.aggregate() - return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} + if bootstrap_aggregation: + result = aggregator.aggregate() + if return_precision_and_recall: + return extract_rouge_mid_statistics(result) # here we return dict + else: + return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} + + else: + return aggregator._scores # here we return defaultdict(list) # Utilities for freezing parameters and checking whether they are frozen @@ -423,9 +468,6 @@ def assert_not_all_frozen(model): assert any(model_grads), f"none of {npars} weights require grad" -# CLI Parsing utils - - def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: """ Parse an argv list of unspecified command line args to a dict.