consistent nn. and nn.functional: part 5 docs (#12161)

This commit is contained in:
Stas Bekman
2021-06-14 13:34:32 -07:00
committed by GitHub
parent 88e84186e5
commit 040283170c
5 changed files with 9 additions and 9 deletions

View File

@@ -59,7 +59,7 @@ classification:
.. code-block:: python
import torch
from torch import nn
from transformers import Trainer
class MultilabelTrainer(Trainer):
@@ -67,7 +67,7 @@ classification:
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
loss_fct = torch.nn.BCEWithLogitsLoss()
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
labels.float().view(-1, self.model.config.num_labels))
return (loss, outputs) if return_outputs else loss