fix run seq2seq (#10547)
This commit is contained in:
committed by
GitHub
parent
54e55b52d4
commit
395ffcd757
@@ -251,7 +251,7 @@ def main():
|
|||||||
pred_logits = pred.predictions
|
pred_logits = pred.predictions
|
||||||
pred_ids = np.argmax(pred_logits, axis=-1)
|
pred_ids = np.argmax(pred_logits, axis=-1)
|
||||||
|
|
||||||
pred.label_ids[pred.label_ids == -100] = 0
|
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
|
||||||
|
|
||||||
pred_str = processor.batch_decode(pred_ids)
|
pred_str = processor.batch_decode(pred_ids)
|
||||||
# we do not want to group tokens when computing the metrics
|
# we do not want to group tokens when computing the metrics
|
||||||
|
|||||||
Reference in New Issue
Block a user