Better typing for num_items_in_batch (#38728)
* fix * style * type checking ? * maybe this ? * fix * can't be an int anymore * fix
This commit is contained in:
@@ -187,14 +187,17 @@ from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
# compute custom loss for 3 labels with different weights
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
|
||||
reduction = "mean" if num_items_in_batch is not None else "sum"
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user