[s2s] round bleu, rouge to 4 digits (#6704)
This commit is contained in:
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -103,7 +103,7 @@ def run_generate():
|
||||
if args.reference_path is None:
|
||||
return
|
||||
# Compute scores
|
||||
score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
|
||||
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
|
||||
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
|
||||
scores: dict = score_fn(output_lns, reference_lns)
|
||||
|
||||
Reference in New Issue
Block a user