From 26a33cfd8c2d6923f41ab98683f33172e8948ff3 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 10 Mar 2021 14:58:22 -0500
Subject: [PATCH] Document Trainer limitation on custom models (#10635)

---
 docs/source/main_classes/trainer.rst | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst
index 4c3bc64f03..a7e3134eab 100644
--- a/docs/source/main_classes/trainer.rst
+++ b/docs/source/main_classes/trainer.rst
@@ -21,7 +21,7 @@ Before instantiating your :class:`~transformers.Trainer`/:class:`~transformers.T
 customization during training.

 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
-`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
+`__ and Native AMP for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.

 Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop which supports
 the above features. To inject custom behavior you can subclass them and override the following methods:
@@ -39,6 +39,18 @@ the above features. To inject custom behavior you can subclass them and override
 - **evaluate** -- Runs an evaluation loop and returns metrics.
 - **predict** -- Returns predictions (with metrics if labels are available) on a test set.

+.. warning::
+
+    The :class:`~transformers.Trainer` class is optimized for 🤗 Transformers models and can have surprising behaviors
+    when you use it on other models. When using it on your own model, make sure:
+
+    - your model always returns tuples or subclasses of :class:`~transformers.file_utils.ModelOutput`.
+    - your model can compute the loss if a :obj:`labels` argument is provided and that loss is returned as the first
+      element of the tuple (if your model returns tuples)
+    - your model can accept multiple label arguments (use the :obj:`label_names` in your
+      :class:`~transformers.TrainingArguments` to indicate their name to the :class:`~transformers.Trainer`) but none
+      of them should be named :obj:`"label"`.
+
 Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
 classification:
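
The requirements listed in the warning can be illustrated with a minimal sketch of a custom PyTorch model. This is not code from the patch: ``TinyClassifier``, its layer sizes, and the ``inputs`` argument name are hypothetical and chosen only for illustration. The sketch assumes a single label argument named ``labels``, always returns a tuple, and puts the loss first whenever labels are provided.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical custom model (not part of transformers) that follows the warning's rules."""

    def __init__(self, num_features: int = 4, num_labels: int = 2):
        super().__init__()
        self.linear = nn.Linear(num_features, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, labels=None):
        # The label argument is called "labels", never "label".
        logits = self.linear(inputs)
        if labels is not None:
            # When labels are given, compute the loss and return it as the first tuple element.
            loss = self.loss_fn(logits, labels)
            return (loss, logits)
        # Without labels, still return a tuple rather than a bare tensor.
        return (logits,)


model = TinyClassifier()
features = torch.randn(3, 4)        # batch of 3 examples with 4 features each
targets = torch.tensor([0, 1, 1])   # one class index per example

with_labels = model(features, labels=targets)
without_labels = model(features)
```

A model with several label arguments would presumably list their names through ``label_names`` in ``TrainingArguments``, as the warning describes; that case is not shown here.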