diff --git a/docs/source/training.rst b/docs/source/training.rst index 9a3e510583..524818b602 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -109,9 +109,9 @@ The following is equivalent to the previous example: .. code-block:: python from torch.nn import functional as F - labels = torch.tensor([1,0]).unsqueeze(0) + labels = torch.tensor([1,0]) outputs = model(input_ids, attention_mask=attention_mask) - loss = F.cross_entropy(labels, outputs.logitd) + loss = F.cross_entropy(outputs.logits, labels) loss.backward() optimizer.step()