[Docs] Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic (#39391)

Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic
This commit is contained in:
MilkClouds
2025-07-15 01:25:06 +09:00
committed by GitHub
parent 3a48e9534c
commit 3635415af2

View File

@@ -187,13 +187,13 @@ from torch import nn
from transformers import Trainer from transformers import Trainer
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None): def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
labels = inputs.pop("labels") labels = inputs.pop("labels")
# forward pass # forward pass
outputs = model(**inputs) outputs = model(**inputs)
logits = outputs.get("logits") logits = outputs.get("logits")
# compute custom loss for 3 labels with different weights # compute custom loss for 3 labels with different weights
reduction = "mean" if num_items_in_batch is not None else "sum" reduction = "sum" if num_items_in_batch is not None else "mean"
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction)) loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
if num_items_in_batch is not None: if num_items_in_batch is not None: