From c28d04e9e252a1a099944e325685f14d242ecdcd Mon Sep 17 00:00:00 2001 From: Divyanshu Kumar <53843818+divyanshugit@users.noreply.github.com> Date: Mon, 3 Oct 2022 18:51:51 +0530 Subject: [PATCH] Update no_trainer script for summarization (#19277) * Update no_trainer script for summarization * removed unnecessary import * fixes notation mistake * removed: unused variable --- .../summarization/run_summarization_no_trainer.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index eb809ada99..594d468330 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -669,7 +669,6 @@ def main(): "max_length": args.val_max_target_length if args is not None else config.max_length, "num_beams": args.num_beams, } - samples_seen = 0 for step, batch in enumerate(eval_dataloader): with torch.no_grad(): generated_tokens = accelerator.unwrap_model(model).generate( @@ -686,7 +685,7 @@ def main(): # If we did not pad to max length, we need to pad the labels too labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id) - generated_tokens, labels = accelerator.gather((generated_tokens, labels)) + generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels) generated_tokens = generated_tokens.cpu().numpy() labels = labels.cpu().numpy() @@ -699,14 +698,8 @@ def main(): decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) - # If we are in a multiprocess environment, the last batch has duplicates - if accelerator.num_processes > 1: - if step == len(eval_dataloader) - 1: - decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen] - decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen] - else: - samples_seen += len(decoded_labels) + decoded_preds, decoded_labels = accelerator.gather_for_metrics(decoded_preds, decoded_labels) metric.add_batch( predictions=decoded_preds, references=decoded_labels,