fix the backward for deepspeed (#9705)

This commit is contained in:
Stas Bekman
2021-01-20 09:07:07 -08:00
committed by GitHub
parent 538245b0c2
commit cd5565bed3

View File

@@ -1282,8 +1282,7 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
# calling on DS engine (model_wrapped == DDP(Deepspeed(PretrainedModule)))
self.model_wrapped.module.backward(loss)
self.deepspeed.backward(loss)
else:
loss.backward()