Log the correct learning rate (#36973)
* fix learning rate log * fix lr log * add lr
This commit is contained in:
@@ -2460,6 +2460,7 @@ class Trainer:
|
|||||||
self._globalstep_last_logged = self.state.global_step
|
self._globalstep_last_logged = self.state.global_step
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
grad_norm: Optional[float] = None
|
grad_norm: Optional[float] = None
|
||||||
|
learning_rate = None
|
||||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||||
|
|
||||||
if args.eval_on_start:
|
if args.eval_on_start:
|
||||||
@@ -2608,6 +2609,9 @@ class Trainer:
|
|||||||
|
|
||||||
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
||||||
|
|
||||||
|
# get leaning rate before update
|
||||||
|
learning_rate = self._get_learning_rate()
|
||||||
|
|
||||||
if not self.accelerator.optimizer_step_was_skipped:
|
if not self.accelerator.optimizer_step_was_skipped:
|
||||||
# Delay optimizer scheduling until metrics are generated
|
# Delay optimizer scheduling until metrics are generated
|
||||||
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
@@ -2618,7 +2622,14 @@ class Trainer:
|
|||||||
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||||
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
||||||
self._maybe_log_save_evaluate(
|
self._maybe_log_save_evaluate(
|
||||||
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
|
tr_loss,
|
||||||
|
grad_norm,
|
||||||
|
model,
|
||||||
|
trial,
|
||||||
|
epoch,
|
||||||
|
ignore_keys_for_eval,
|
||||||
|
start_time,
|
||||||
|
learning_rate=learning_rate,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
|
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
|
||||||
@@ -2644,7 +2655,9 @@ class Trainer:
|
|||||||
self.control.should_training_stop = True
|
self.control.should_training_stop = True
|
||||||
|
|
||||||
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
||||||
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
|
self._maybe_log_save_evaluate(
|
||||||
|
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
|
||||||
|
)
|
||||||
|
|
||||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
@@ -3064,7 +3077,9 @@ class Trainer:
|
|||||||
) from exc
|
) from exc
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
|
def _maybe_log_save_evaluate(
|
||||||
|
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
|
||||||
|
):
|
||||||
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
@@ -3080,6 +3095,9 @@ class Trainer:
|
|||||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||||
if grad_norm is not None:
|
if grad_norm is not None:
|
||||||
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
||||||
|
if learning_rate is not None:
|
||||||
|
logs["learning_rate"] = learning_rate
|
||||||
|
else:
|
||||||
logs["learning_rate"] = self._get_learning_rate()
|
logs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
self._total_loss_scalar += tr_loss_scalar
|
self._total_loss_scalar += tr_loss_scalar
|
||||||
|
|||||||
Reference in New Issue
Block a user