Fix gradient clipping for Sharded DDP (#9168)
* Fix gradient clipping for Sharded DDP
* Fix typos in comments
This commit is contained in:
@@ -804,14 +804,23 @@ class Trainer:
                     steps_in_epoch <= self.args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
+                    # Gradient clipping
                     if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
                         if self.use_amp:
+                            # AMP: gradients need unscaling
                             self.scaler.unscale_(self.optimizer)
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-                        elif self.use_apex:
-                            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
-                        else:
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+
+                        if hasattr(self.optimizer, "clip_grad_norm"):
+                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                        else:
+                            # Revert to normal clipping otherwise, handling Apex or full precision
+                            torch.nn.utils.clip_grad_norm_(
+                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+                                self.args.max_grad_norm,
+                            )
+
+                    # Optimizer step
                     if is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
                     elif self.use_amp:
Reference in New Issue
Block a user