add a callback hook right before the optimizer step (#33444)
This commit is contained in:
@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
|
||||
def on_step_begin(self, args, state, control, **kwargs):
|
||||
self.events.append("on_step_begin")
|
||||
|
||||
def on_pre_optimizer_step(self, args, state, control, **kwargs):
|
||||
self.events.append("on_pre_optimizer_step")
|
||||
|
||||
def on_optimizer_step(self, args, state, control, **kwargs):
|
||||
self.events.append("on_optimizer_step")
|
||||
|
||||
@@ -151,7 +154,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
expected_events.append("on_epoch_begin")
|
||||
for _ in range(train_dl_len):
|
||||
step += 1
|
||||
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
|
||||
expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
|
||||
if step % trainer.args.logging_steps == 0:
|
||||
expected_events.append("on_log")
|
||||
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||
|
||||
Reference in New Issue
Block a user