From 3635415af2b6f6e0d6345849f07ba6da9af4206a Mon Sep 17 00:00:00 2001 From: MilkClouds Date: Tue, 15 Jul 2025 01:25:06 +0900 Subject: [PATCH] [Docs] Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic (#39391) Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic --- docs/source/en/trainer.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 3572fb4385..48325da689 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -187,13 +187,13 @@ from torch import nn from transformers import Trainer class CustomTrainer(Trainer): - def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None): + def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None): labels = inputs.pop("labels") # forward pass outputs = model(**inputs) logits = outputs.get("logits") # compute custom loss for 3 labels with different weights - reduction = "mean" if num_items_in_batch is not None else "sum" + reduction = "sum" if num_items_in_batch is not None else "mean" loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction)) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) if num_items_in_batch is not None: