Log the correct learning rate (#36973)

* fix learning rate log

* fix lr log

* add lr
This commit is contained in:
Marc Sun
2025-03-26 16:52:00 +01:00
committed by GitHub
parent 13d36e89fe
commit 79d6f9fd70

View File

@@ -2460,6 +2460,7 @@ class Trainer:
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
grad_norm: Optional[float] = None
learning_rate = None
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if args.eval_on_start:
@@ -2608,6 +2609,9 @@ class Trainer:
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
# get learning rate before update
learning_rate = self._get_learning_rate()
if not self.accelerator.optimizer_step_was_skipped:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2618,7 +2622,14 @@ class Trainer:
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
tr_loss,
grad_norm,
model,
trial,
epoch,
ignore_keys_for_eval,
start_time,
learning_rate=learning_rate,
)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2644,7 +2655,9 @@ class Trainer:
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_xla_available():
@@ -3064,7 +3077,9 @@ class Trainer:
) from exc
return metrics
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
def _maybe_log_save_evaluate(
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step()
@@ -3080,6 +3095,9 @@ class Trainer:
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
if learning_rate is not None:
logs["learning_rate"] = learning_rate
else:
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar