Fix ValueError when eval_do_concat_batches=False with examples (#37621)
https://github.com/huggingface/transformers/issues/37593 Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -529,6 +529,9 @@ def main():
|
||||
|
||||
def compute_metrics(p):
|
||||
predictions, labels = p
|
||||
if not training_args.eval_do_concat_batches:
|
||||
predictions = np.hstack(predictions)
|
||||
labels = np.hstack(labels)
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
|
||||
# Remove ignored index (special tokens)
|
||||
|
||||
Reference in New Issue
Block a user