From f6fe41c96bf76b4324dc84fec15e3e7b861b7428 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Wed, 18 Nov 2020 15:58:08 +0100 Subject: [PATCH] Reset loss to zero on logging in Trainer to avoid bfloat16 issues (#8561) * make tr_loss regular float * Revert "make tr_loss regular float" This reverts commit c9d7ccfaf0c4387187b0841694f01ec0ffd5f4ba. * reset loss at each logging step * keep track of total loss with _total_loss_scalar * add remaining tr_loss at the end --- src/transformers/trainer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 72f8f7d985..0e1eef74cc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -696,8 +696,10 @@ class Trainer: self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() + # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss = torch.tensor(0.0).to(self.args.device) - self._logging_loss_scalar = 0 + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 self._globalstep_last_logged = 0 self._total_flos = self.state.total_flos model.zero_grad() @@ -812,23 +814,26 @@ class Trainer: self.log({"total_flos": self.state.total_flos}) self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() - return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) + return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step) def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): if self.control.should_log: logs: Dict[str, float] = {} tr_loss_scalar = tr_loss.item() - logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / ( - self.state.global_step - self._globalstep_last_logged - ) + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged) # backward compatibility for pytorch schedulers logs["learning_rate"] = ( self.lr_scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else self.lr_scheduler.get_lr()[0] ) - self._logging_loss_scalar = tr_loss_scalar + self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step self.log(logs)