[Docs] Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic (#39391)
Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic
This commit is contained in:
@@ -187,13 +187,13 @@ from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
# compute custom loss for 3 labels with different weights
|
||||
reduction = "mean" if num_items_in_batch is not None else "sum"
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
if num_items_in_batch is not None:
|
||||
|
||||
Reference in New Issue
Block a user