Update README.md
This commit is contained in:
@@ -780,7 +780,8 @@ def cross_entropy(logits, labels):
|
|||||||
# define a function which will run the forward pass return loss
|
# define a function which will run the forward pass return loss
|
||||||
def compute_loss(params, input_ids, labels):
|
def compute_loss(params, input_ids, labels):
|
||||||
logits = model(input_ids, params=params, train=True)
|
logits = model(input_ids, params=params, train=True)
|
||||||
loss = cross_entropy(logits, onehot(labels)).mean()
|
num_classes = logits.shape[-1]
|
||||||
|
loss = cross_entropy(logits, onehot(labels, num_classes)).mean()
|
||||||
return loss
|
return loss
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user