Fix gradient clipping for Sharded DDP (#9168)
* Fix gradient clipping for Sharded DDP
* Fix typos in comments
This commit is contained in:
@@ -804,14 +804,23 @@ class Trainer:
|
|||||||
steps_in_epoch <= self.args.gradient_accumulation_steps
|
steps_in_epoch <= self.args.gradient_accumulation_steps
|
||||||
and (step + 1) == steps_in_epoch
|
and (step + 1) == steps_in_epoch
|
||||||
):
|
):
|
||||||
if self.use_amp:
|
# Gradient clipping
|
||||||
self.scaler.unscale_(self.optimizer)
|
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
if self.use_amp:
|
||||||
elif self.use_apex:
|
# AMP: gradients need unscaling
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
self.scaler.unscale_(self.optimizer)
|
||||||
else:
|
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
|
||||||
|
|
||||||
|
if hasattr(self.optimizer, "clip_grad_norm"):
|
||||||
|
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
|
||||||
|
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
|
||||||
|
else:
|
||||||
|
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||||
|
torch.nn.utils.clip_grad_norm_(
|
||||||
|
amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
|
||||||
|
self.args.max_grad_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer step
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.optimizer_step(self.optimizer)
|
xm.optimizer_step(self.optimizer)
|
||||||
elif self.use_amp:
|
elif self.use_amp:
|
||||||
|
|||||||
Reference in New Issue
Block a user