Fix ValueError when eval_do_concat_batches=False with examples (#37621)

https://github.com/huggingface/transformers/issues/37593

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jeffhataws
2025-04-22 03:13:25 -07:00
committed by GitHub
parent 85665a4263
commit 964a1b6b7d
2 changed files with 8 additions and 1 deletions

View File

@@ -529,6 +529,9 @@ def main():
def compute_metrics(p):
predictions, labels = p
if not training_args.eval_do_concat_batches:
predictions = np.hstack(predictions)
labels = np.hstack(labels)
predictions = np.argmax(predictions, axis=2)
# Remove ignored index (special tokens)