fix the backward for deepspeed (#9705)
This commit is contained in:
@@ -1282,8 +1282,7 @@ class Trainer:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
elif self.deepspeed:
|
||||
# calling on DS engine (model_wrapped == DDP(Deepspeed(PretrainedModule)))
|
||||
self.model_wrapped.module.backward(loss)
|
||||
self.deepspeed.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user