Fix gather for metrics (#19360)

This commit is contained in:
Zachary Mueller
2022-10-05 14:52:01 -04:00
committed by GitHub
parent d9101b71bc
commit ad98642a82

View File

@@ -685,7 +685,7 @@ def main():
# If we did not pad to max length, we need to pad the labels too # If we did not pad to max length, we need to pad the labels too
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id) labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels) generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
generated_tokens = generated_tokens.cpu().numpy() generated_tokens = generated_tokens.cpu().numpy()
labels = labels.cpu().numpy() labels = labels.cpu().numpy()