Update README.md

This commit is contained in:
Sanchit Gandhi
2022-03-10 10:20:37 +01:00
committed by GitHub
parent fde901877a
commit 6c9010ef63

View File

@@ -780,7 +780,8 @@ def cross_entropy(logits, labels):
# define a function which will run the forward pass return loss
def compute_loss(params, input_ids, labels):
logits = model(input_ids, params=params, train=True)
loss = cross_entropy(logits, onehot(labels)).mean()
num_classes = logits.shape[-1]
loss = cross_entropy(logits, onehot(labels, num_classes)).mean()
return loss
```