[s2s] round bleu, rouge to 4 digits (#6704)
This commit is contained in:
@@ -23,7 +23,7 @@ try:
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
@@ -42,7 +42,7 @@ except ImportError:
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
return calculate_bleu(preds, target)
|
||||
|
||||
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
|
||||
Reference in New Issue
Block a user