From 1ab147d648defff2071de253a3dbc1b1c4d24e4d Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Mon, 31 May 2021 10:29:04 -0400 Subject: [PATCH] Remove redundant `nn.log_softmax` in `run_flax_glue.py` (#11920) * Remove redundant `nn.log_softmax` in `run_flax_glue.py` `optax.softmax_cross_entropy` expects unnormalized logits, and so it already calls `nn.log_softmax`, so I believe it is not needed here. `nn.log_softmax` is idempotent so mathematically it shouldn't have made a difference. * Remove unused 'flax.linen' import --- examples/flax/text-classification/run_flax_glue.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 24aac7defd..899cdbd9b1 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -29,7 +29,6 @@ import jax import jax.numpy as jnp import optax import transformers -from flax import linen as nn from flax import struct, traverse_util from flax.jax_utils import replicate, unreplicate from flax.metrics import tensorboard @@ -202,7 +201,6 @@ def create_train_state( else: # Classification. def cross_entropy_loss(logits, labels): - logits = nn.log_softmax(logits) xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels)) return jnp.mean(xentropy)