Fix gather for metrics (#19360)
This commit is contained in:
@@ -685,7 +685,7 @@ def main():
|
|||||||
# If we did not pad to max length, we need to pad the labels too
|
# If we did not pad to max length, we need to pad the labels too
|
||||||
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
||||||
|
|
||||||
generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels)
|
generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
|
||||||
generated_tokens = generated_tokens.cpu().numpy()
|
generated_tokens = generated_tokens.cpu().numpy()
|
||||||
labels = labels.cpu().numpy()
|
labels = labels.cpu().numpy()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user