Replace -100s in predictions by the pad token (#22693)
* Replace -100s in predictions by the pad token * Style * Try to catch them all
This commit is contained in:
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
|
||||
|
||||
@@ -614,6 +615,8 @@ def main():
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
# Replace -100s used for padding as we can't decode them
|
||||
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
|
||||
Reference in New Issue
Block a user