From 3db4378bd7815330dc03506910ac7b5d78e19be1 Mon Sep 17 00:00:00 2001
From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com>
Date: Wed, 3 Aug 2022 18:03:18 +0530
Subject: [PATCH] Update no trainer scripts for language modeling and image classification examples (#18443)

* Update no_trainer script for image-classification

* Update no_trainer scripts for language-modeling examples

* Remove unused variable

* Removing truncation from losses array for language modeling examples
---
 .../run_image_classification_no_trainer.py          | 10 +---------
 .../pytorch/language-modeling/run_clm_no_trainer.py |  3 +--
 .../pytorch/language-modeling/run_mlm_no_trainer.py |  3 +--
 3 files changed, 3 insertions(+), 13 deletions(-)

diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index b6dc53b716..f10a54add7 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -508,19 +508,11 @@ def main():
                 break
 
         model.eval()
-        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
-            predictions, references = accelerator.gather((predictions, batch["labels"]))
-            # If we are in a multiprocess environment, the last batch has duplicates
-            if accelerator.num_processes > 1:
-                if step == len(eval_dataloader) - 1:
-                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
-                    references = references[: len(eval_dataloader.dataset) - samples_seen]
-                else:
-                    samples_seen += references.shape[0]
+            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
             metric.add_batch(
                 predictions=predictions,
                 references=references,
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index 183ac24a0b..21dc568fd4 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -597,10 +597,9 @@ def main():
                 outputs = model(**batch)
 
             loss = outputs.loss
-            losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
 
         losses = torch.cat(losses)
-        losses = losses[: len(eval_dataset)]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index 07299d47e4..b7b085e5b6 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -642,10 +642,9 @@ def main():
                 outputs = model(**batch)
 
             loss = outputs.loss
-            losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
 
         losses = torch.cat(losses)
-        losses = losses[: len(eval_dataset)]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)