From e1c02e018c09a6ddb6d5465d6dff0f3816aa8501 Mon Sep 17 00:00:00 2001 From: Amala Deshmukh Date: Mon, 5 Apr 2021 12:27:23 -0400 Subject: [PATCH] Add example for registering callbacks with trainers (#10928) * Add example for callback registry Resolves: #9036 * Update callback registry documentation * Added comments for other ways to register callback --- docs/source/main_classes/callback.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst index 464c41ff82..3a7934bdce 100644 --- a/docs/source/main_classes/callback.rst +++ b/docs/source/main_classes/callback.rst @@ -74,6 +74,32 @@ TrainerCallback .. autoclass:: transformers.TrainerCallback :members: +Here is an example of how to register a custom callback with the PyTorch :class:`~transformers.Trainer`: + +.. code-block:: python + + class MyCallback(TrainerCallback): + "A callback that prints a message at the beginning of training" + + def on_train_begin(self, args, state, control, **kwargs): + print("Starting training") + + trainer = Trainer( + model, + args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + callbacks=[MyCallback] # We can either pass the callback class this way or an instance of it (MyCallback()) + ) + +Another way to register a callback is to call ``trainer.add_callback()`` as follows: + +.. code-block:: python + + trainer = Trainer(...) + trainer.add_callback(MyCallback) + # Alternatively, we can pass an instance of the callback class + trainer.add_callback(MyCallback()) TrainerState ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~