From 12b66215cfa25d03a7e9f0a8ed001d0235a77730 Mon Sep 17 00:00:00 2001
From: lewtun
Date: Fri, 5 Mar 2021 13:44:53 +0100
Subject: [PATCH] Fix example of custom Trainer to reflect signature of compute_loss (#10537)

---
 docs/source/main_classes/trainer.rst | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst
index a6edaccf3e..4c3bc64f03 100644
--- a/docs/source/main_classes/trainer.rst
+++ b/docs/source/main_classes/trainer.rst
@@ -23,14 +23,14 @@ customization during training.
 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
 `__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
 
-Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
-previous features. To inject custom behavior you can subclass them and override the following methods:
+Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop which supports
+the above features. To inject custom behavior you can subclass them and override the following methods:
 
 - **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
 - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
 - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
 - **log** -- Logs information on the various objects watching training.
-- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
+- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
   init.
 - **compute_loss** - Computes the loss on a batch of training inputs.
 - **training_step** -- Performs a training step.
@@ -39,17 +39,23 @@ previous features. To inject custom behavior you can subclass them and override
 - **evaluate** -- Runs an evaluation loop and returns metrics.
 - **predict** -- Returns predictions (with metrics if labels are available) on a test set.
 
-Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
+Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
+classification:
 
 .. code-block:: python
 
+    import torch
     from transformers import Trainer
-    class MyTrainer(Trainer):
-        def compute_loss(self, model, inputs):
+
+    class MultilabelTrainer(Trainer):
+        def compute_loss(self, model, inputs, return_outputs=False):
             labels = inputs.pop("labels")
             outputs = model(**inputs)
-            logits = outputs[0]
-            return my_custom_loss(logits, labels)
+            logits = outputs.logits
+            loss_fct = torch.nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
+                            labels.float().view(-1, self.model.config.num_labels))
+            return (loss, outputs) if return_outputs else loss
 
 Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
 :doc:`callbacks ` that can inspect the training loop state (for progress reporting, logging on TensorBoard or