Fix example of custom Trainer to reflect signature of compute_loss (#10537)
@@ -23,14 +23,14 @@ customization during training.
 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
 <https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
 
-Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
-previous features. To inject custom behavior you can subclass them and override the following methods:
+Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop which supports
+the above features. To inject custom behavior you can subclass them and override the following methods:
 
 - **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
 - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
 - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
 - **log** -- Logs information on the various objects watching training.
-- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
+- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
   init.
 - **compute_loss** - Computes the loss on a batch of training inputs.
 - **training_step** -- Performs a training step.
@@ -39,17 +39,23 @@ previous features. To inject custom behavior you can subclass them and override
 - **evaluate** -- Runs an evaluation loop and returns metrics.
 - **predict** -- Returns predictions (with metrics if labels are available) on a test set.
 
-Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
+Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
+classification:
 
 .. code-block:: python
 
+    import torch
     from transformers import Trainer
-    class MyTrainer(Trainer):
-        def compute_loss(self, model, inputs):
+
+    class MultilabelTrainer(Trainer):
+        def compute_loss(self, model, inputs, return_outputs=False):
             labels = inputs.pop("labels")
             outputs = model(**inputs)
-            logits = outputs[0]
-            return my_custom_loss(logits, labels)
+            logits = outputs.logits
+            loss_fct = torch.nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
+                            labels.float().view(-1, self.model.config.num_labels))
+            return (loss, outputs) if return_outputs else loss
 
 Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
 :doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
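
Note for readers checking what the new example computes: the sketch below works out, in plain Python and without PyTorch, the quantity that `torch.nn.BCEWithLogitsLoss` returns under its default mean reduction, namely the binary cross-entropy averaged over every (example, label) pair. The helper name `bce_with_logits` and the sample values are hypothetical and not part of this commit. It is an illustration of the formula only, not a replacement for the library call.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy over every (example, label) pair.

    Plain-Python stand-in (an assumption for illustration) for
    torch.nn.BCEWithLogitsLoss with its default 'mean' reduction.
    `logits` and `labels` are equally shaped nested lists of floats.
    """
    total, count = 0.0, 0
    for logit_row, label_row in zip(logits, labels):
        for x, y in zip(logit_row, label_row):
            # Numerically stable form of
            # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


# Two examples with three independent labels each (hypothetical values).
logits = [[0.0, 2.0, -2.0], [1.0, -1.0, 0.0]]
labels = [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
loss = bce_with_logits(logits, labels)
```

Each label is scored independently, which is why the example in the diff reshapes both tensors to `(-1, num_labels)` and casts the labels to float before calling the loss.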