From af1ee9e648e644aae45c5250a46d4c25a4e0d04d Mon Sep 17 00:00:00 2001 From: Victor SANH Date: Wed, 8 Jan 2020 15:47:53 -0500 Subject: [PATCH] Move `torch.nn.utils.clip_grad_norm_` --- examples/distillation/run_squad_w_distillation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 91d3802f6b..e71e4ffbbd 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -204,13 +204,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None): if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad()