Replace -100s in predictions by the pad token (#22693)
* Replace -100s in predictions by the pad token
* Style
* Try to catch them all
This commit is contained in:
@@ -543,10 +543,10 @@ def main():
|
||||
preds, labels = eval_preds
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
# Replace -100s used for padding as we can't decode them
|
||||
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
if data_args.ignore_pad_token_for_loss:
|
||||
# Replace -100 in the labels as we can't decode them.
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
# Some simple post-processing
|
||||
@@ -626,8 +626,10 @@ def main():
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
predictions = predict_results.predictions
|
||||
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
||||
predictions = tokenizer.batch_decode(
|
||||
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
predictions = [pred.strip() for pred in predictions]
|
||||
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||
|
||||
Reference in New Issue
Block a user