From 5c88253556b7f15cf7d7e9793d7b2a39b4aa588a Mon Sep 17 00:00:00 2001 From: Dhruv Pai <46631243+dhruvbpai@users.noreply.github.com> Date: Wed, 29 May 2024 07:20:59 -0700 Subject: [PATCH] Add on_optimizer_step to callback options (#31095) * Modified test * Added on_optimizer_step to callbacks * Move callback after step is called * Added on optimizer step callback --- src/transformers/trainer.py | 2 ++ src/transformers/trainer_callback.py | 9 +++++++++ tests/trainer/test_trainer_callback.py | 5 ++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 58e5fd14b6..49e7803066 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2306,6 +2306,8 @@ class Trainer: self.optimizer.step() + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 45ecf7c80c..207d8ebdff 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -345,6 +345,12 @@ class TrainerCallback: """ pass + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients. + """ + pass + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): """ Event called at the end of an substep during gradient accumulation. @@ -470,6 +476,9 @@ class CallbackHandler(TrainerCallback): control.should_save = False return self.call_event("on_step_begin", args, state, control) + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_optimizer_step", args, state, control) + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): return self.call_event("on_substep_end", args, state, control) diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index 9eeb1d5e41..edd73f29dc 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -78,6 +78,9 @@ class MyTestTrainerCallback(TrainerCallback): def on_step_begin(self, args, state, control, **kwargs): self.events.append("on_step_begin") + def on_optimizer_step(self, args, state, control, **kwargs): + self.events.append("on_optimizer_step") + def on_step_end(self, args, state, control, **kwargs): self.events.append("on_step_end") @@ -148,7 +151,7 @@ class TrainerCallbackTest(unittest.TestCase): expected_events.append("on_epoch_begin") for _ in range(train_dl_len): step += 1 - expected_events += ["on_step_begin", "on_step_end"] + expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"] if step % trainer.args.logging_steps == 0: expected_events.append("on_log") if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: