Update Trainer code example (#15070)

* Update code example

* Fix code quality

* Add comment
This commit is contained in:
NielsRogge
2022-01-19 20:15:12 +01:00
committed by GitHub
parent ac227093e4
commit 80f7296091

View File

@@ -47,22 +47,22 @@ when you use it on other models. When using it on your own model, make sure:
</Tip>
Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set):
```python
from torch import nn
from transformers import Trainer
class MultilabelTrainer(Trainer):
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.get("labels")
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits")
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(
logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels)
)
# compute custom loss (suppose one has 3 labels with different weights)
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
```