Fix multiproc metrics in no_trainer examples (#16865)
This commit is contained in:
@@ -628,6 +628,7 @@ def main():
|
||||
"max_length": args.val_max_target_length if args is not None else config.max_length,
|
||||
"num_beams": args.num_beams,
|
||||
}
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
generated_tokens = accelerator.unwrap_model(model).generate(
|
||||
@@ -644,8 +645,9 @@ def main():
|
||||
# If we did not pad to max length, we need to pad the labels too
|
||||
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
||||
|
||||
generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
|
||||
labels = accelerator.gather(labels).cpu().numpy()
|
||||
generated_tokens, labels = accelerator.gather((generated_tokens, labels))
|
||||
generated_tokens = generated_tokens.cpu().numpy()
|
||||
labels = labels.cpu().numpy()
|
||||
|
||||
if args.ignore_pad_token_for_loss:
|
||||
# Replace -100 in the labels as we can't decode them.
|
||||
@@ -656,8 +658,18 @@ def main():
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||
# If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader):
|
||||
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
|
||||
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += decoded_labels.shape[0]
|
||||
|
||||
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
|
||||
metric.add_batch(
|
||||
predictions=decoded_preds,
|
||||
references=decoded_labels,
|
||||
)
|
||||
result = metric.compute(use_stemmer=True)
|
||||
# Extract a few results from ROUGE
|
||||
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
||||
|
||||
Reference in New Issue
Block a user