Fix gather for metrics (#19360)
This commit is contained in:
@@ -685,7 +685,7 @@ def main():
|
||||
# If we did not pad to max length, we need to pad the labels too
|
||||
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
||||
|
||||
generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels)
|
||||
generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
|
||||
generated_tokens = generated_tokens.cpu().numpy()
|
||||
labels = labels.cpu().numpy()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user