Fix ValueError when eval_do_concat_batches=False with examples (#37621)

https://github.com/huggingface/transformers/issues/37593

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jeffhataws
2025-04-22 03:13:25 -07:00
committed by GitHub
parent 85665a4263
commit 964a1b6b7d
2 changed files with 8 additions and 1 deletions

View File

@@ -508,8 +508,12 @@ def main():
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
labels = p.label_ids
if not training_args.eval_do_concat_batches:
preds = np.concatenate(preds, axis=0)
labels = np.concatenate(p.label_ids, axis=0)
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
result = metric.compute(predictions=preds, references=p.label_ids)
result = metric.compute(predictions=preds, references=labels)
if len(result) > 1:
result["combined_score"] = np.mean(list(result.values())).item()
return result

View File

@@ -529,6 +529,9 @@ def main():
def compute_metrics(p):
predictions, labels = p
if not training_args.eval_do_concat_batches:
predictions = np.hstack(predictions)
labels = np.hstack(labels)
predictions = np.argmax(predictions, axis=2)
# Remove ignored index (special tokens)