diff --git a/examples/run_glue.py b/examples/run_glue.py index 57d1c56ac1..f8b17978eb 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -233,7 +233,11 @@ def train(args, train_dataset, model, tokenizer): loss.backward() tr_loss += loss.item() - if (step + 1) % args.gradient_accumulation_steps == 0: + if (step + 1) % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + len(epoch_iterator) <= args.gradient_accumulation_steps + and (step + 1) == len(epoch_iterator) + ): if args.fp16: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: