[s2s] round bleu, rouge to 4 digits (#6704)

This commit is contained in:
Sam Shleifer
2020-08-25 00:33:11 -04:00
committed by GitHub
parent b6512d2357
commit 0344428f79
4 changed files with 12 additions and 12 deletions

View File

@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
return list(map(f, x))
def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
"""Uses sacrebleu's corpus_bleu implementation."""
return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
def trim_batch(
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
aggregator.add_scores(scores)
result = aggregator.aggregate()
return {k: v.mid.fmeasure * 100 for k, v in result.items()}
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
def freeze_params(model: nn.Module):