Fix gradient clipping for Sharded DDP (#9168)

* Fix gradient clipping for Sharded DDP

* Fix typos in comments
This commit is contained in:
Sylvain Gugger
2020-12-17 09:44:24 -05:00
committed by GitHub
parent 1aca3d6afa
commit 77d6941e64

View File

@@ -804,14 +804,23 @@ class Trainer:
steps_in_epoch <= self.args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
if self.use_amp:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.use_apex:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
# Gradient clipping
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
if self.use_amp:
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
self.args.max_grad_norm,
)
# Optimizer step
if is_torch_tpu_available():
xm.optimizer_step(self.optimizer)
elif self.use_amp: