Avoid unnecessary DDP synchronization when gradient_accumulation_steps > 1 (#7742)
* use DDP no_sync when possible
* fix is_nlp_available addition mistake
* reformat trainer.py
* reformat trainer.py
* drop support for pytorch < 1.2
* return support for pytorch < 1.2
This commit is contained in:
@@ -101,6 +101,11 @@ else:
|
|||||||
_use_native_amp = True
|
_use_native_amp = True
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.2"):
|
||||||
|
_use_ddp_no_sync = False
|
||||||
|
else:
|
||||||
|
_use_ddp_no_sync = True
|
||||||
|
|
||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
@@ -687,7 +692,15 @@ class Trainer:
|
|||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||||
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||||
|
|
||||||
tr_loss += self.training_step(model, inputs)
|
if (
|
||||||
|
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
||||||
|
and self.args.local_rank != -1
|
||||||
|
and _use_ddp_no_sync
|
||||||
|
):
|
||||||
|
with model.no_sync():
|
||||||
|
tr_loss += self.training_step(model, inputs)
|
||||||
|
else:
|
||||||
|
tr_loss += self.training_step(model, inputs)
|
||||||
self._total_flos += self.floating_point_ops(inputs)
|
self._total_flos += self.floating_point_ops(inputs)
|
||||||
|
|
||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||||
|
|||||||
Reference in New Issue
Block a user