[Docs] Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic (#39391)
Fix typo in CustomTrainer compute_loss method and adjust loss reduction logic
This commit is contained in:
@@ -187,13 +187,13 @@ from torch import nn
|
|||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
# compute custom loss for 3 labels with different weights
|
# compute custom loss for 3 labels with different weights
|
||||||
reduction = "mean" if num_items_in_batch is not None else "sum"
|
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
|
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
|
||||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||||
if num_items_in_batch is not None:
|
if num_items_in_batch is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user