From 6c9010ef63da1570e5a651a05bb00855b7075514 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 10 Mar 2022 10:20:37 +0100 Subject: [PATCH] Update README.md --- examples/research_projects/jax-projects/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/jax-projects/README.md b/examples/research_projects/jax-projects/README.md index ed162db406..2a6449f348 100644 --- a/examples/research_projects/jax-projects/README.md +++ b/examples/research_projects/jax-projects/README.md @@ -780,7 +780,8 @@ def cross_entropy(logits, labels): # define a function which will run the forward pass return loss def compute_loss(params, input_ids, labels): logits = model(input_ids, params=params, train=True) - loss = cross_entropy(logits, onehot(labels)).mean() + num_classes = logits.shape[-1] + loss = cross_entropy(logits, onehot(labels, num_classes)).mean() return loss ```