Trainer callbacks (#7596)
* Initial callback proposal * Finish various callbacks * Post-rebase conflicts * Fix tests * Don't use something that's not set * Documentation * Remove unwanted print. * Document all models can work * Add tests + small fixes * Update docs/source/internal/trainer_utils.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Fix TF tests * Real fix this time * This one should work * Fix typo * Really fix typo Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
68
docs/source/main_classes/callback.rst
Normal file
68
docs/source/main_classes/callback.rst
Normal file
@@ -0,0 +1,68 @@
|
||||
Callbacks
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
|
||||
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
|
||||
state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
|
||||
stopping).
|
||||
|
||||
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
|
||||
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
|
||||
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
|
||||
|
||||
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||
|
||||
- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation.
|
||||
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
|
||||
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
||||
it's the second one).
|
||||
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
||||
or tensorboardX).
|
||||
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||
|
||||
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
|
||||
:class:`~transformers.TrainerControl`.
|
||||
|
||||
|
||||
Available Callbacks
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
|
||||
|
||||
.. autoclass:: transformers.integrations.CometCallback
|
||||
:members: setup
|
||||
|
||||
.. autoclass:: transformers.DefaultFlowCallback
|
||||
|
||||
.. autoclass:: transformers.PrinterCallback
|
||||
|
||||
.. autoclass:: transformers.ProgressCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.TensorBoardCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.WandbCallback
|
||||
:members: setup
|
||||
|
||||
|
||||
TrainerCallback
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerCallback
|
||||
:members:
|
||||
|
||||
|
||||
TrainerState
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerState
|
||||
:members:
|
||||
|
||||
|
||||
TrainerControl
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerControl
|
||||
:members:
|
||||
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
|
||||
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
|
||||
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
|
||||
- **log** -- Logs information on the various objects watching training.
|
||||
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
||||
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
||||
init.
|
||||
- **compute_loss** - Computes the loss on a batch of training inputs.
|
||||
@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
|
||||
logits = outputs[0]
|
||||
return my_custom_loss(logits, labels)
|
||||
|
||||
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
|
||||
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
|
||||
other ML platforms...) and take decisions (like early stopping).
|
||||
|
||||
|
||||
Trainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -47,29 +50,23 @@ Trainer
|
||||
.. autoclass:: transformers.Trainer
|
||||
:members:
|
||||
|
||||
|
||||
TFTrainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTrainer
|
||||
:members:
|
||||
|
||||
|
||||
TrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
TFTrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTrainingArguments
|
||||
:members:
|
||||
|
||||
Utilities
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.EvalPrediction
|
||||
|
||||
.. autofunction:: transformers.set_seed
|
||||
|
||||
.. autofunction:: transformers.torch_distributed_zero_first
|
||||
|
||||
Reference in New Issue
Block a user