Reset loss to zero on logging in Trainer to avoid bfloat16 issues (#8561)
* make tr_loss regular float
* Revert "make tr_loss regular float"
  This reverts commit c9d7ccfaf0c4387187b0841694f01ec0ffd5f4ba.
* reset loss at each logging step
* keep track of total loss with _total_loss_scalar
* add remaining tr_loss at the end
This commit is contained in:
committed by
GitHub
parent
b592728eff
commit
f6fe41c96b
@@ -696,8 +696,10 @@ class Trainer:
|
||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||
|
||||
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
self._logging_loss_scalar = 0
|
||||
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
|
||||
self._total_loss_scalar = 0.0
|
||||
self._globalstep_last_logged = 0
|
||||
self._total_flos = self.state.total_flos
|
||||
model.zero_grad()
|
||||
@@ -812,23 +814,26 @@ class Trainer:
|
||||
self.log({"total_flos": self.state.total_flos})
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
logs: Dict[str, float] = {}
|
||||
tr_loss_scalar = tr_loss.item()
|
||||
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
|
||||
self.state.global_step - self._globalstep_last_logged
|
||||
)
|
||||
# reset tr_loss to zero
|
||||
tr_loss -= tr_loss
|
||||
|
||||
logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
|
||||
# backward compatibility for pytorch schedulers
|
||||
logs["learning_rate"] = (
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else self.lr_scheduler.get_lr()[0]
|
||||
)
|
||||
self._logging_loss_scalar = tr_loss_scalar
|
||||
self._total_loss_scalar += tr_loss_scalar
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
self.log(logs)
|
||||
|
||||
Reference in New Issue
Block a user