Fix multiproc metrics in no_trainer examples (#16865)

This commit is contained in:
Zachary Mueller
2022-04-20 17:26:27 -04:00
committed by GitHub
parent 175da8d182
commit 705d65368f
7 changed files with 79 additions and 14 deletions

View File

@@ -457,12 +457,21 @@ def main():
break
# Evaluation pass: run the model over eval_dataloader, gather predictions and
# labels from all processes, and feed them to `metric` before computing it.
model.eval()
samples_seen = 0
# `enumerate` is zero-based, so the final batch has index len(eval_dataloader) - 1
# (comparing against len(eval_dataloader) itself would never match).
last_step = len(eval_dataloader) - 1
for step, batch in enumerate(eval_dataloader):
    outputs = model(**batch)
    predictions = outputs.logits.argmax(dim=-1)
    # Gather exactly once; the gathered tensors are what gets passed to the metric.
    predictions, references = accelerator.gather((predictions, batch["labels"]))
    # If we are in a multiprocess environment, the last batch has duplicates
    if accelerator.num_processes > 1:
        if step == last_step:
            # Keep only as many entries as the dataset still has uncounted samples,
            # dropping the duplicated tail of the gathered last batch.
            remaining = len(eval_dataloader.dataset) - samples_seen
            predictions = predictions[:remaining]
            references = references[:remaining]
        else:
            samples_seen += references.shape[0]
    metric.add_batch(
        predictions=predictions,
        references=references,
    )
eval_metric = metric.compute()