From 8cb4ecca251d2074b77d4318d5b0982b38286106 Mon Sep 17 00:00:00 2001
From: Noam Wies
Date: Tue, 13 Oct 2020 16:46:44 +0300
Subject: [PATCH] Avoid unnecessary DDP synchronization when gradient_accumulation_steps > 1 (#7742)

* use DDP no_sync when possible

* fix is_nlp_available addition mistake

* reformat trainer.py

* reformat trainer.py

* drop support for pytorch < 1.2

* return support for pytorch < 1.2
---
 src/transformers/trainer.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index b8e6e494b8..4a7d11d325 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -101,6 +101,11 @@ else:
     _use_native_amp = True
     from torch.cuda.amp import autocast
 
+if version.parse(torch.__version__) < version.parse("1.2"):
+    _use_ddp_no_sync = False
+else:
+    _use_ddp_no_sync = True
+
 if is_datasets_available():
     import datasets
 
@@ -687,7 +692,15 @@ class Trainer:
                 if (step + 1) % self.args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
 
-                tr_loss += self.training_step(model, inputs)
+                if (
+                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
+                    and self.args.local_rank != -1
+                    and _use_ddp_no_sync
+                ):
+                    with model.no_sync():
+                        tr_loss += self.training_step(model, inputs)
+                else:
+                    tr_loss += self.training_step(model, inputs)
                 self._total_flos += self.floating_point_ops(inputs)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (