Update trainer.mdx class_weights example (#23787)

class_weights tensor should follow model's device
This commit is contained in:
amitportnoy
2023-05-26 15:36:33 +03:00
committed by GitHub
parent 4d9b76a80f
commit d61d747627

View File

@@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
outputs = model(**inputs) outputs = model(**inputs)
logits = outputs.get("logits") logits = outputs.get("logits")
# compute custom loss (suppose one has 3 labels with different weights) # compute custom loss (suppose one has 3 labels with different weights)
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0])) loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss
``` ```