consistent nn. and nn.functional: part 5 docs (#12161)
This commit is contained in:
@@ -59,7 +59,7 @@ classification:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class MultilabelTrainer(Trainer):
|
||||
@@ -67,7 +67,7 @@ classification:
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
loss_fct = torch.nn.BCEWithLogitsLoss()
|
||||
loss_fct = nn.BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
|
||||
labels.float().view(-1, self.model.config.num_labels))
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
Reference in New Issue
Block a user