fixed document (#13414)
This commit is contained in:
@@ -64,9 +64,9 @@ classification:
|
|||||||
|
|
||||||
class MultilabelTrainer(Trainer):
|
class MultilabelTrainer(Trainer):
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
labels = inputs.pop("labels")
|
labels = inputs.get("labels")
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.get('logits')
|
||||||
loss_fct = nn.BCEWithLogitsLoss()
|
loss_fct = nn.BCEWithLogitsLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
|
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
|
||||||
labels.float().view(-1, self.model.config.num_labels))
|
labels.float().view(-1, self.model.config.num_labels))
|
||||||
|
|||||||
Reference in New Issue
Block a user