Add on_optimizer_step to callback options (#31095)

* Modified test

* Added on_optimizer_step to callbacks

* Move callback after step is called

* Added on optimizer step callback
This commit is contained in:
Dhruv Pai
2024-05-29 07:20:59 -07:00
committed by GitHub
parent 4af705c6ce
commit 5c88253556
3 changed files with 15 additions and 1 deletions

View File

@@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback):
def on_step_begin(self, args, state, control, **kwargs):
self.events.append("on_step_begin")
def on_optimizer_step(self, args, state, control, **kwargs):
self.events.append("on_optimizer_step")
def on_step_end(self, args, state, control, **kwargs):
self.events.append("on_step_end")
@@ -148,7 +151,7 @@ class TrainerCallbackTest(unittest.TestCase):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_step_end"]
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: