From 79d6f9fd7006bd6a7860f63f43f0e6e89a0412a4 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 26 Mar 2025 16:52:00 +0100 Subject: [PATCH] Log the correct learning rate (#36973) * fix learning rate log * fix lr log * add lr --- src/transformers/trainer.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f843d9be75..189f09d3de 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2460,6 +2460,7 @@ class Trainer: self._globalstep_last_logged = self.state.global_step model.zero_grad() grad_norm: Optional[float] = None + learning_rate = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) if args.eval_on_start: @@ -2608,6 +2609,9 @@ class Trainer: self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + # get leaning rate before update + learning_rate = self._get_learning_rate() + if not self.accelerator.optimizer_step_was_skipped: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): @@ -2618,7 +2622,14 @@ class Trainer: self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate( - tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time + tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=learning_rate, ) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) @@ -2644,7 +2655,9 @@ class Trainer: self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time) + self._maybe_log_save_evaluate( + tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate + ) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: if is_torch_xla_available(): @@ -3064,7 +3077,9 @@ class Trainer: ) from exc return metrics - def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): + def _maybe_log_save_evaluate( + self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None + ): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: if is_torch_xla_available(): xm.mark_step() @@ -3080,7 +3095,10 @@ class Trainer: logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) if grad_norm is not None: logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm - logs["learning_rate"] = self._get_learning_rate() + if learning_rate is not None: + logs["learning_rate"] = learning_rate + else: + logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step