From d61d7476275a6f31fcf68df334c68ac6bf739166 Mon Sep 17 00:00:00 2001 From: amitportnoy <113588658+amitportnoy@users.noreply.github.com> Date: Fri, 26 May 2023 15:36:33 +0300 Subject: [PATCH] Update trainer.mdx class_weights example (#23787) class_weights tensor should follow model's device --- docs/source/en/main_classes/trainer.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/trainer.mdx b/docs/source/en/main_classes/trainer.mdx index 67ab6aba42..409a6c6d33 100644 --- a/docs/source/en/main_classes/trainer.mdx +++ b/docs/source/en/main_classes/trainer.mdx @@ -61,7 +61,7 @@ class CustomTrainer(Trainer): outputs = model(**inputs) logits = outputs.get("logits") # compute custom loss (suppose one has 3 labels with different weights) - loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0])) + loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device)) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) return (loss, outputs) if return_outputs else loss ```