Update Trainer code example (#15070)
* Update code example * Fix code quality * Add comment
This commit is contained in:
@@ -47,22 +47,22 @@ when you use it on other models. When using it on your own model, make sure:
|
||||
|
||||
</Tip>
|
||||
|
||||
Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
|
||||
Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set):
|
||||
|
||||
```python
|
||||
from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class MultilabelTrainer(Trainer):
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.get("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
loss_fct = nn.BCEWithLogitsLoss()
|
||||
loss = loss_fct(
|
||||
logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels)
|
||||
)
|
||||
# compute custom loss (suppose one has 3 labels with different weights)
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user