Compute loss method (#7074)
This commit is contained in:
@@ -21,12 +21,25 @@ previous features. To inject custom behavior you can subclass them and override
|
|||||||
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
||||||
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
||||||
init.
|
init.
|
||||||
|
- **compute_loss** - Computes the loss on a batch of training inputs.
|
||||||
- **training_step** -- Performs a training step.
|
- **training_step** -- Performs a training step.
|
||||||
- **prediction_step** -- Performs an evaluation/test step.
|
- **prediction_step** -- Performs an evaluation/test step.
|
||||||
- **run_model** (TensorFlow only) -- Basic pass through the model.
|
- **run_model** (TensorFlow only) -- Basic pass through the model.
|
||||||
- **evaluate** -- Runs an evaluation loop and returns metrics.
|
- **evaluate** -- Runs an evaluation loop and returns metrics.
|
||||||
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
|
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
|
||||||
|
|
||||||
|
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
class MyTrainer(Trainer):
|
||||||
|
def compute_loss(self, model, inputs):
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
outputs = models(**inputs)
|
||||||
|
logits = outputs[0]
|
||||||
|
return my_custom_loss(logits, labels)
|
||||||
|
|
||||||
|
|
||||||
``Trainer``
|
``Trainer``
|
||||||
~~~~~~~~~~~
|
~~~~~~~~~~~
|
||||||
|
|||||||
@@ -1024,15 +1024,9 @@ class Trainer:
|
|||||||
|
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.args.fp16 and _use_native_amp:
|
||||||
with autocast():
|
with autocast():
|
||||||
outputs = model(**inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
loss = outputs[0]
|
|
||||||
else:
|
else:
|
||||||
outputs = model(**inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
|
||||||
loss = outputs[0]
|
|
||||||
|
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = outputs[self.args.past_index]
|
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
@@ -1050,6 +1044,19 @@ class Trainer:
|
|||||||
|
|
||||||
return loss.detach()
|
return loss.detach()
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs):
|
||||||
|
"""
|
||||||
|
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||||
|
|
||||||
|
Subclass and override for custom behavior.
|
||||||
|
"""
|
||||||
|
outputs = model(**inputs)
|
||||||
|
# Save past state if it exists
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = outputs[self.args.past_index]
|
||||||
|
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
def is_local_master(self) -> bool:
|
def is_local_master(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
||||||
|
|||||||
Reference in New Issue
Block a user