Fix gradient accumulation (GA) loss for DeepSpeed (#35808)
* Fix gradient accumulation (GA) loss for DeepSpeed
* Turn off loss scaling in the DeepSpeed engine by passing `scale_wrt_gas=False`
* Add comment linking to the PR
This commit is contained in:
@@ -3722,6 +3722,11 @@ class Trainer:
|
||||
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
|
||||
# https://github.com/huggingface/transformers/pull/35808
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
kwargs["scale_wrt_gas"] = False
|
||||
|
||||
self.accelerator.backward(loss, **kwargs)
|
||||
|
||||
return loss.detach()
|
||||
|
||||
Reference in New Issue
Block a user