From ed8fad73903c670d41a9dff173bc44995cda2d2f Mon Sep 17 00:00:00 2001 From: Mathieu Prouveur Date: Wed, 24 Apr 2019 14:07:00 +0200 Subject: [PATCH] Update example files so that tr_loss is not affected by args.gradient_accumulation_step --- examples/run_classifier.py | 2 +- examples/run_swag.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index b90ac494e4..e14788cacb 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -845,7 +845,7 @@ def main(): else: loss.backward() - tr_loss += loss.item() + tr_loss += loss.item() * args.gradient_accumulation_steps nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 if (step + 1) % args.gradient_accumulation_steps == 0: diff --git a/examples/run_swag.py b/examples/run_swag.py index a6cfdbe311..5a65d7a748 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -452,7 +452,7 @@ def main(): loss = loss * args.loss_scale if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps - tr_loss += loss.item() + tr_loss += loss.item() * args.gradient_accumulation_steps nb_tr_examples += input_ids.size(0) nb_tr_steps += 1