Log the correct learning rate (#36973)
* fix learning rate log * fix lr log * add lr
This commit is contained in:
@@ -2460,6 +2460,7 @@ class Trainer:
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
model.zero_grad()
|
||||
grad_norm: Optional[float] = None
|
||||
learning_rate = None
|
||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||
|
||||
if args.eval_on_start:
|
||||
@@ -2608,6 +2609,9 @@ class Trainer:
|
||||
|
||||
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
||||
|
||||
# get leaning rate before update
|
||||
learning_rate = self._get_learning_rate()
|
||||
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
# Delay optimizer scheduling until metrics are generated
|
||||
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||
@@ -2618,7 +2622,14 @@ class Trainer:
|
||||
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
|
||||
tr_loss,
|
||||
grad_norm,
|
||||
model,
|
||||
trial,
|
||||
epoch,
|
||||
ignore_keys_for_eval,
|
||||
start_time,
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
else:
|
||||
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
|
||||
@@ -2644,7 +2655,9 @@ class Trainer:
|
||||
self.control.should_training_stop = True
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
||||
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
|
||||
)
|
||||
|
||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||
if is_torch_xla_available():
|
||||
@@ -3064,7 +3077,9 @@ class Trainer:
|
||||
) from exc
|
||||
return metrics
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
|
||||
def _maybe_log_save_evaluate(
|
||||
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
|
||||
):
|
||||
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
@@ -3080,7 +3095,10 @@ class Trainer:
|
||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||
if grad_norm is not None:
|
||||
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
||||
logs["learning_rate"] = self._get_learning_rate()
|
||||
if learning_rate is not None:
|
||||
logs["learning_rate"] = learning_rate
|
||||
else:
|
||||
logs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
self._total_loss_scalar += tr_loss_scalar
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
Reference in New Issue
Block a user