diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 91d3802f6b..e71e4ffbbd 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -204,13 +204,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None): if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad()