From 80f72960913ab6682451c33dfa8035ef0c932128 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 19 Jan 2022 20:15:12 +0100 Subject: [PATCH] Update Trainer code example (#15070) * Update code example * Fix code quality * Add comment --- docs/source/main_classes/trainer.mdx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/main_classes/trainer.mdx b/docs/source/main_classes/trainer.mdx index a193b40ac8..b40cb9ee18 100644 --- a/docs/source/main_classes/trainer.mdx +++ b/docs/source/main_classes/trainer.mdx @@ -47,22 +47,22 @@ when you use it on other models. When using it on your own model, make sure: -Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification: +Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set): ```python from torch import nn from transformers import Trainer -class MultilabelTrainer(Trainer): +class CustomTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): labels = inputs.get("labels") + # forward pass outputs = model(**inputs) logits = outputs.get("logits") - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct( - logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels) - ) + # compute custom loss (suppose one has 3 labels with different weights) + loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0])) + loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) return (loss, outputs) if return_outputs else loss ```