Update trainer.mdx class_weights example (#23787)
class_weights tensor should follow model's device
This commit is contained in:
@@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
# compute custom loss (suppose one has 3 labels with different weights)
|
# compute custom loss (suppose one has 3 labels with different weights)
|
||||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
|
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
|
||||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
```
|
```
|
||||||
|
|||||||
Reference in New Issue
Block a user