[s2s] rougeLSum expects \n between sentences (#7410)
Co-authored-by: Swetha Mandava <smandava@nvidia.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from sacrebleu import corpus_bleu
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
|
||||
@@ -378,19 +379,63 @@ def get_git_info():
|
||||
return repo_infos
|
||||
|
||||
|
||||
# Default ROUGE variants scored by calculate_rouge.
# rougeLsum expects "\n" separated sentences within a summary.
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
||||
|
||||
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
|
||||
def extract_rouge_mid_statistics(dct):
    """Reduce aggregated scores to their rounded ``mid`` statistics.

    Args:
        dct: mapping of metric name to an object whose ``mid`` attribute
            exposes ``precision``, ``recall`` and ``fmeasure``.

    Returns:
        Dict mapping each metric name to a dict with the keys ``precision``,
        ``recall`` and ``fmeasure``, each rounded to 4 decimal places.
    """
    stat_names = ("precision", "recall", "fmeasure")
    return {
        metric: {name: round(getattr(aggregate.mid, name), 4) for name in stat_names}
        for metric, aggregate in dct.items()
    }
|
||||
|
||||
|
||||
def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
            this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]
        newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating
            rougeLsum on multi sentence summaries (CNN/DM dataset).

    Returns:
        Dict[score: value] if bootstrap_aggregation else defaultdict(list) keyed by rouge_keys.
        With bootstrap_aggregation, values are fmeasure * 100 rounded to 4 places, or, when
        return_precision_and_recall is set, the dict built by extract_rouge_mid_statistics.
    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for tgt, pred in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n" separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        # The scorer is called with the target first and the prediction second.
        scores = scorer.score(tgt, pred)
        aggregator.add_scores(scores)

    if not bootstrap_aggregation:
        # NOTE(review): reads a private attribute of BootstrapAggregator to expose
        # the raw per-observation scores; here we return defaultdict(list)
        return aggregator._scores

    result = aggregator.aggregate()
    if return_precision_and_recall:
        return extract_rouge_mid_statistics(result)  # here we return dict
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
||||
|
||||
|
||||
# Utilities for freezing parameters and checking whether they are frozen
|
||||
@@ -423,9 +468,6 @@ def assert_not_all_frozen(model):
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
||||
|
||||
|
||||
# CLI Parsing utils
|
||||
|
||||
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||
"""
|
||||
Parse an argv list of unspecified command line args to a dict.
|
||||
|
||||
Reference in New Issue
Block a user