[s2s] rougeLSum expects \n between sentences (#7410)

Co-authored-by: Swetha Mandava <smandava@nvidia.com>
This commit is contained in:
Sam Shleifer
2020-09-27 16:27:19 -04:00
committed by GitHub
parent eab5f59682
commit 7296fea1d6
7 changed files with 176 additions and 14 deletions

View File

@@ -18,6 +18,7 @@ from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer
from transformers.file_utils import cached_property
@@ -378,19 +379,63 @@ def get_git_info():
return repo_infos
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
def extract_rouge_mid_statistics(dct):
new_dict = {}
for k1, v1 in dct.items():
mid = v1.mid
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
return new_dict
def calculate_rouge(
pred_lns: List[str],
tgt_lns: List[str],
use_stemmer=True,
rouge_keys=ROUGE_KEYS,
return_precision_and_recall=False,
bootstrap_aggregation=True,
newline_sep=True,
) -> Dict:
"""Calculate rouge using rouge_scorer package.
Args:
pred_lns: list of summaries generated by model
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
return_precision_and_recall: (False) whether to also return precision and recall.
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
on multi sentence summaries (CNN/DM dataset).
Returns:
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
"""
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
scores = scorer.score(reference_ln, output_ln)
for pred, tgt in zip(tgt_lns, pred_lns):
# rougeLsum expects "\n" separated sentences within a summary
if newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
tgt = add_newline_to_end_of_each_sentence(tgt)
scores = scorer.score(pred, tgt)
aggregator.add_scores(scores)
result = aggregator.aggregate()
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
if bootstrap_aggregation:
result = aggregator.aggregate()
if return_precision_and_recall:
return extract_rouge_mid_statistics(result) # here we return dict
else:
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
else:
return aggregator._scores # here we return defaultdict(list)
# Utilities for freezing parameters and checking whether they are frozen
@@ -423,9 +468,6 @@ def assert_not_all_frozen(model):
assert any(model_grads), f"none of {npars} weights require grad"
# CLI Parsing utils
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.