From 77d6941e6425c3e92113265ebbb89936c53d2937 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 17 Dec 2020 09:44:24 -0500
Subject: [PATCH] Fix gradient clipping for Sharded DDP (#9168)

* Fix gradient clipping for Sharded DDP

* Fix typos in comments
---
 src/transformers/trainer.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d6c5a23d29..0bad522175 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -804,14 +804,23 @@ class Trainer:
                     steps_in_epoch <= self.args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
-                    if self.use_amp:
-                        self.scaler.unscale_(self.optimizer)
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-                    elif self.use_apex:
-                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
-                    else:
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    # Gradient clipping
+                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
+                        if self.use_amp:
+                            # AMP: gradients need unscaling
+                            self.scaler.unscale_(self.optimizer)
 
+                        if hasattr(self.optimizer, "clip_grad_norm"):
+                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                        else:
+                            # Revert to normal clipping otherwise, handling Apex or full precision
+                            torch.nn.utils.clip_grad_norm_(
+                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+                                self.args.max_grad_norm,
+                            )
+
+                    # Optimizer step
                     if is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
                     elif self.use_amp: