add a callback hook right before the optimizer step (#33444)
This commit is contained in:
@@ -2417,6 +2417,8 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
grad_norm = _grad_norm
|
grad_norm = _grad_norm
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
|
||||||
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
|
||||||
|
|||||||
@@ -344,6 +344,12 @@ class TrainerCallback:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
"""
|
"""
|
||||||
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
|
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
|
||||||
@@ -475,6 +481,9 @@ class CallbackHandler(TrainerCallback):
|
|||||||
control.should_save = False
|
control.should_save = False
|
||||||
return self.call_event("on_step_begin", args, state, control)
|
return self.call_event("on_step_begin", args, state, control)
|
||||||
|
|
||||||
|
def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_pre_optimizer_step", args, state, control)
|
||||||
|
|
||||||
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
return self.call_event("on_optimizer_step", args, state, control)
|
return self.call_event("on_optimizer_step", args, state, control)
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
|
|||||||
def on_step_begin(self, args, state, control, **kwargs):
|
def on_step_begin(self, args, state, control, **kwargs):
|
||||||
self.events.append("on_step_begin")
|
self.events.append("on_step_begin")
|
||||||
|
|
||||||
|
def on_pre_optimizer_step(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_pre_optimizer_step")
|
||||||
|
|
||||||
def on_optimizer_step(self, args, state, control, **kwargs):
|
def on_optimizer_step(self, args, state, control, **kwargs):
|
||||||
self.events.append("on_optimizer_step")
|
self.events.append("on_optimizer_step")
|
||||||
|
|
||||||
@@ -151,7 +154,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
expected_events.append("on_epoch_begin")
|
expected_events.append("on_epoch_begin")
|
||||||
for _ in range(train_dl_len):
|
for _ in range(train_dl_len):
|
||||||
step += 1
|
step += 1
|
||||||
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
|
expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
|
||||||
if step % trainer.args.logging_steps == 0:
|
if step % trainer.args.logging_steps == 0:
|
||||||
expected_events.append("on_log")
|
expected_events.append("on_log")
|
||||||
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user