Add on_optimizer_step to callback options (#31095)
* Modified test * Added on_optimizer_step to callbacks * Move callback after step is called * Added on optimizer step callback
This commit is contained in:
@@ -2306,6 +2306,8 @@ class Trainer:
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
||||
|
||||
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
||||
if optimizer_was_run:
|
||||
# Delay optimizer scheduling until metrics are generated
|
||||
|
||||
@@ -345,6 +345,12 @@ class TrainerCallback:
|
||||
"""
|
||||
pass
|
||||
|
||||
# Callback hook for the "on_optimizer_step" event. The Trainer hunk in this diff fires it
# on the line directly after `self.optimizer.step()`.
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
|
||||
"""
|
||||
# Intentionally a no-op here: the hook only does something when a subclass overrides it.
pass
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the end of a substep during gradient accumulation.
|
||||
@@ -470,6 +476,9 @@ class CallbackHandler(TrainerCallback):
|
||||
control.should_save = False
|
||||
return self.call_event("on_step_begin", args, state, control)
|
||||
|
||||
# Handler-side entry point for the "on_optimizer_step" event: forwards args, state and
# control to `call_event` under that event name and returns whatever `call_event` returns.
# The Trainer assigns this return value back to `self.control`.
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_optimizer_step", args, state, control)
|
||||
|
||||
# Handler-side entry point for the "on_substep_end" event: forwards args, state and
# control to `call_event` under that event name and returns whatever `call_event` returns.
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_substep_end", args, state, control)
|
||||
|
||||
|
||||
@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
|
||||
# Test hook: records that "on_step_begin" fired by appending its name to `self.events`.
# args, state, control and kwargs are accepted but not used.
def on_step_begin(self, args, state, control, **kwargs):
|
||||
self.events.append("on_step_begin")
|
||||
|
||||
# Test hook: records that "on_optimizer_step" fired by appending its name to `self.events`.
# args, state, control and kwargs are accepted but not used.
def on_optimizer_step(self, args, state, control, **kwargs):
|
||||
self.events.append("on_optimizer_step")
|
||||
|
||||
# Test hook: records that "on_step_end" fired by appending its name to `self.events`.
# args, state, control and kwargs are accepted but not used.
def on_step_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_step_end")
|
||||
|
||||
@@ -148,7 +151,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
expected_events.append("on_epoch_begin")
|
||||
for _ in range(train_dl_len):
|
||||
step += 1
|
||||
expected_events += ["on_step_begin", "on_step_end"]
|
||||
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
|
||||
if step % trainer.args.logging_steps == 0:
|
||||
expected_events.append("on_log")
|
||||
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||
|
||||
Reference in New Issue
Block a user