fixed a typo (#16508)

This commit is contained in:
Bhadresh Savani
2022-03-31 17:19:02 +05:30
committed by GitHub
parent 6a4dbba1a3
commit 05b4c32908

View File

@@ -791,7 +791,7 @@ def main():
if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered
start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
end_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
all_start_logits.append(accelerator.gather(start_logits).cpu().numpy()) all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
all_end_logits.append(accelerator.gather(end_logits).cpu().numpy()) all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())