add a callback hook right before the optimizer step (#33444)

This commit is contained in:
Wing Lian
2024-09-13 04:43:45 -04:00
committed by GitHub
parent 9c4639b622
commit 1027a532c5
3 changed files with 15 additions and 1 deletions

View File

@@ -2417,6 +2417,8 @@ class Trainer:
else: else:
grad_norm = _grad_norm grad_norm = _grad_norm
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
self.optimizer.step() self.optimizer.step()
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

View File

@@ -344,6 +344,12 @@ class TrainerCallback:
""" """
pass pass
def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
"""
pass
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
""" """
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients. Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
@@ -475,6 +481,9 @@ class CallbackHandler(TrainerCallback):
control.should_save = False control.should_save = False
return self.call_event("on_step_begin", args, state, control) return self.call_event("on_step_begin", args, state, control)
def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_pre_optimizer_step", args, state, control)
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_optimizer_step", args, state, control) return self.call_event("on_optimizer_step", args, state, control)

View File

@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
def on_step_begin(self, args, state, control, **kwargs): def on_step_begin(self, args, state, control, **kwargs):
self.events.append("on_step_begin") self.events.append("on_step_begin")
def on_pre_optimizer_step(self, args, state, control, **kwargs):
self.events.append("on_pre_optimizer_step")
def on_optimizer_step(self, args, state, control, **kwargs): def on_optimizer_step(self, args, state, control, **kwargs):
self.events.append("on_optimizer_step") self.events.append("on_optimizer_step")
@@ -151,7 +154,7 @@ class TrainerCallbackTest(unittest.TestCase):
expected_events.append("on_epoch_begin") expected_events.append("on_epoch_begin")
for _ in range(train_dl_len): for _ in range(train_dl_len):
step += 1 step += 1
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"] expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
if step % trainer.args.logging_steps == 0: if step % trainer.args.logging_steps == 0:
expected_events.append("on_log") expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: