From ece6c514586cf925f1d12cf9c7d472aa6f85a9e0 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Mon, 8 Feb 2021 10:08:16 -0500
Subject: [PATCH] [s2s examples] Replace -100 token ids with the tokenizer pad_id for compute_metrics (#10046)

* replace -100 token ids with the tokenizer pad_id for compute_metrics

* fixed typo for label_ids
---
 examples/seq2seq/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 303b89f781..2b4700e9f7 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -82,8 +82,11 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
         return np.count_nonzero(tokens != tokenizer.pad_token_id)
 
     def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
-        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
-        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
+        pred_ids = pred.predictions
+        label_ids = pred.label_ids
+        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        label_ids[label_ids == -100] = tokenizer.pad_token_id
+        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
         pred_str = lmap(str.strip, pred_str)
         label_str = lmap(str.strip, label_str)
         return pred_str, label_str