Update README.md
This commit is contained in:
@@ -780,7 +780,8 @@ def cross_entropy(logits, labels):
|
||||
# define a function which will run the forward pass return loss
|
||||
def compute_loss(params, input_ids, labels):
|
||||
logits = model(input_ids, params=params, train=True)
|
||||
loss = cross_entropy(logits, onehot(labels)).mean()
|
||||
num_classes = logits.shape[-1]
|
||||
loss = cross_entropy(logits, onehot(labels, num_classes)).mean()
|
||||
return loss
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user