From 0344428f7955675847ef95ddcb4980236b6f8721 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 25 Aug 2020 00:33:11 -0400 Subject: [PATCH] [s2s] round bleu, rouge to 4 digits (#6704) --- examples/seq2seq/distillation.py | 6 +++--- examples/seq2seq/finetune.py | 6 +++--- examples/seq2seq/run_eval.py | 6 +++--- examples/seq2seq/utils.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 2d5c719474..82194981a0 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -20,7 +20,7 @@ try: from .utils import ( any_requires_grad, assert_all_frozen, - calculate_bleu_score, + calculate_bleu, freeze_params, pickle_load, use_task_specific_params, @@ -32,7 +32,7 @@ except ImportError: from utils import ( any_requires_grad, assert_all_frozen, - calculate_bleu_score, + calculate_bleu, freeze_params, pickle_load, use_task_specific_params, @@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller): self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] def calc_generative_metrics(self, preds, target) -> dict: - return calculate_bleu_score(preds, target) + return calculate_bleu(preds, target) @staticmethod def add_model_specific_args(parser, root_dir): diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 3d8e793a45..539b296142 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -23,7 +23,7 @@ try: Seq2SeqDataset, TranslationDataset, assert_all_frozen, - calculate_bleu_score, + calculate_bleu, calculate_rouge, flatten_list, freeze_params, @@ -42,7 +42,7 @@ except ImportError: Seq2SeqDataset, TranslationDataset, assert_all_frozen, - calculate_bleu_score, + calculate_bleu, calculate_rouge, flatten_list, freeze_params, @@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule): self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang def calc_generative_metrics(self, preds, target) -> dict: - return calculate_bleu_score(preds, target) + return calculate_bleu(preds, target) def main(args, model=None) -> SummarizationModule: diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index 49b5a3448b..c4e6bffa6a 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer try: - from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params + from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params except ImportError: - from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params + from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -103,7 +103,7 @@ def run_generate(): if args.reference_path is None: return # Compute scores - score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge + score_fn = calculate_bleu if "translation" in args.task else calculate_rouge output_lns = [x.rstrip() for x in open(args.save_path).readlines()] reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)] scores: dict = score_fn(output_lns, reference_lns) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index b96158f21a..80a31f462a 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List: return list(map(f, x)) -def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: +def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: """Uses sacrebleu's corpus_bleu implementation.""" - return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} + return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} def trim_batch( @@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer aggregator.add_scores(scores) result = aggregator.aggregate() - return {k: v.mid.fmeasure * 100 for k, v in result.items()} + return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} def freeze_params(model: nn.Module):