Replace -100s in predictions by the pad token (#22693)

* Replace -100s in predictions by the pad token

* Style

* Try to catch them all
This commit is contained in:
Sylvain Gugger
2023-04-11 09:32:20 -04:00
committed by GitHub
parent ff73deeb0e
commit 1b1867d86b
3 changed files with 15 additions and 8 deletions

View File

@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple
import datasets
import evaluate
import numpy as np
from datasets import load_dataset
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
@@ -614,6 +615,8 @@ def main():
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
# Replace -100s used for padding as we can't decode them
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# Build a map example to its corresponding features.