Better typing for num_items_in_batch (#38728)

* fix

* style

* type checking ?

* maybe this ?

* fix

* can't be an int anymore

* fix
This commit is contained in:
Marc Sun
2025-06-11 16:26:41 +02:00
committed by GitHub
parent 84710a4291
commit 11ad9be153
5 changed files with 47 additions and 18 deletions

View File

@@ -187,14 +187,17 @@ from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits")
# compute custom loss for 3 labels with different weights
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
reduction = "mean" if num_items_in_batch is not None else "sum"
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return (loss, outputs) if return_outputs else loss
```