From 8d1d1ffde25f6c7b3c472bd8f82f462b14ab8c11 Mon Sep 17 00:00:00 2001 From: Matthew Carrigan Date: Mon, 25 Mar 2019 12:15:19 +0000 Subject: [PATCH] Corrected the displayed loss when gradient_accumulation_steps > 1 --- examples/lm_finetuning/finetune_on_pregenerated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index e57710be3d..035f97b0c9 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -309,7 +309,7 @@ def main(): nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 pbar.update(1) - mean_loss = tr_loss / nb_tr_steps + mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps pbar.set_postfix_str(f"Loss: {mean_loss:.5f}") if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: