fixed a typo (#16508)
This commit is contained in:
@@ -791,7 +791,7 @@ def main():
|
|||||||
|
|
||||||
if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered
|
if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered
|
||||||
start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
|
start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
|
||||||
end_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
|
end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
|
||||||
|
|
||||||
all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
|
all_start_logits.append(accelerator.gather(start_logits).cpu().numpy())
|
||||||
all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
|
all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
|
||||||
|
|||||||
Reference in New Issue
Block a user