Update Trainer code example (#15070)
* Update code example * Fix code quality * Add comment
This commit is contained in:
@@ -47,22 +47,22 @@ when you use it on other models. When using it on your own model, make sure:
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
|
Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|
||||||
class MultilabelTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
labels = inputs.get("labels")
|
labels = inputs.get("labels")
|
||||||
|
# forward pass
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
loss_fct = nn.BCEWithLogitsLoss()
|
# compute custom loss (suppose one has 3 labels with different weights)
|
||||||
loss = loss_fct(
|
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
|
||||||
logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels)
|
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||||
)
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user