Avoid unnecessary DDP synchronization when gradient_accumulation_steps > 1 (#7742)
* use DDP no_sync when possible * fix is_nlp_available addition mistake * reformat trainer.py * reformat trainer.py * drop support for pytorch < 1.2 * return support for pytorch < 1.2
This commit is contained in:
@@ -101,6 +101,11 @@ else:
|
||||
_use_native_amp = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.2"):
|
||||
_use_ddp_no_sync = False
|
||||
else:
|
||||
_use_ddp_no_sync = True
|
||||
|
||||
if is_datasets_available():
|
||||
import datasets
|
||||
|
||||
@@ -687,6 +692,14 @@ class Trainer:
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||
|
||||
if (
|
||||
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
||||
and self.args.local_rank != -1
|
||||
and _use_ddp_no_sync
|
||||
):
|
||||
with model.no_sync():
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
else:
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
self._total_flos += self.floating_point_ops(inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user