Fix multiproc metrics in no_trainer examples (#16865)
This commit is contained in:
@@ -658,6 +658,7 @@ def main():
|
||||
break
|
||||
|
||||
model.eval()
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
@@ -666,9 +667,14 @@ def main():
|
||||
if not args.pad_to_max_length: # necessary to pad predictions and labels for being gathered
|
||||
predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
|
||||
labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||
|
||||
predictions_gathered = accelerator.gather(predictions)
|
||||
labels_gathered = accelerator.gather(labels)
|
||||
predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
|
||||
# If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader):
|
||||
predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
|
||||
labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += labels_gathered.shape[0]
|
||||
preds, refs = get_labels(predictions_gathered, labels_gathered)
|
||||
metric.add_batch(
|
||||
predictions=preds,
|
||||
|
||||
Reference in New Issue
Block a user