diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 55b308a74e..e6c2c27222 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -21,12 +21,25 @@ previous features. To inject custom behavior you can subclass them and override - **setup_wandb** -- Setups wandb (see `here `__ for more information). - **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at init. +- **compute_loss** - Computes the loss on a batch of training inputs. - **training_step** -- Performs a training step. - **prediction_step** -- Performs an evaluation/test step. - **run_model** (TensorFlow only) -- Basic pass through the model. - **evaluate** -- Runs an evaluation loop and returns metrics. - **predict** -- Returns predictions (with metrics if labels are available) on a test set. +Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function: + +.. code-block:: python + + from transformers import Trainer + class MyTrainer(Trainer): + def compute_loss(self, model, inputs): + labels = inputs.pop("labels") + outputs = models(**inputs) + logits = outputs[0] + return my_custom_loss(logits, labels) + ``Trainer`` ~~~~~~~~~~~ diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a981ff9f6d..e13087d60a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1024,15 +1024,9 @@ class Trainer: if self.args.fp16 and _use_native_amp: with autocast(): - outputs = model(**inputs) - loss = outputs[0] + loss = self.compute_loss(model, inputs) else: - outputs = model(**inputs) - # We don't use .loss here since the model may return tuples instead of ModelOutput. - loss = outputs[0] - - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index] + loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training @@ -1050,6 +1044,19 @@ class Trainer: return loss.detach() + def compute_loss(self, model, inputs): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + outputs = model(**inputs) + # Save past state if it exists + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + # We don't use .loss here since the model may return tuples instead of ModelOutput. + return outputs[0] + def is_local_master(self) -> bool: """ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on