[s2s] round bleu, rouge to 4 digits (#6704)
This commit is contained in:
@@ -20,7 +20,7 @@ try:
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
any_requires_grad,
|
any_requires_grad,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu_score,
|
calculate_bleu,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
pickle_load,
|
pickle_load,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
@@ -32,7 +32,7 @@ except ImportError:
|
|||||||
from utils import (
|
from utils import (
|
||||||
any_requires_grad,
|
any_requires_grad,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu_score,
|
calculate_bleu,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
pickle_load,
|
pickle_load,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
@@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
|
|||||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||||
|
|
||||||
def calc_generative_metrics(self, preds, target) -> dict:
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
return calculate_bleu_score(preds, target)
|
return calculate_bleu(preds, target)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ try:
|
|||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
TranslationDataset,
|
TranslationDataset,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu_score,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
@@ -42,7 +42,7 @@ except ImportError:
|
|||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
TranslationDataset,
|
TranslationDataset,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu_score,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
flatten_list,
|
flatten_list,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
|
|||||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||||
|
|
||||||
def calc_generative_metrics(self, preds, target) -> dict:
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
return calculate_bleu_score(preds, target)
|
return calculate_bleu(preds, target)
|
||||||
|
|
||||||
|
|
||||||
def main(args, model=None) -> SummarizationModule:
|
def main(args, model=None) -> SummarizationModule:
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ def run_generate():
|
|||||||
if args.reference_path is None:
|
if args.reference_path is None:
|
||||||
return
|
return
|
||||||
# Compute scores
|
# Compute scores
|
||||||
score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
|
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
|
||||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||||
scores: dict = score_fn(output_lns, reference_lns)
|
scores: dict = score_fn(output_lns, reference_lns)
|
||||||
|
|||||||
@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
|
|||||||
return list(map(f, x))
|
return list(map(f, x))
|
||||||
|
|
||||||
|
|
||||||
def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
|
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||||
"""Uses sacrebleu's corpus_bleu implementation."""
|
"""Uses sacrebleu's corpus_bleu implementation."""
|
||||||
return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
|
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||||
|
|
||||||
|
|
||||||
def trim_batch(
|
def trim_batch(
|
||||||
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
|
|||||||
aggregator.add_scores(scores)
|
aggregator.add_scores(scores)
|
||||||
|
|
||||||
result = aggregator.aggregate()
|
result = aggregator.aggregate()
|
||||||
return {k: v.mid.fmeasure * 100 for k, v in result.items()}
|
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
||||||
|
|
||||||
|
|
||||||
def freeze_params(model: nn.Module):
|
def freeze_params(model: nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user