Remove redundant nn.log_softmax in run_flax_glue.py (#11920)

* Remove redundant `nn.log_softmax` in `run_flax_glue.py`

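`optax.softmax_cross_entropy` expects unnormalized logits and already applies `log_softmax` internally, so I believe the explicit `nn.log_softmax` call is not needed here. Because `nn.log_softmax` is idempotent (applying it to already-normalized log-probabilities returns them unchanged), the extra call should not have made a difference mathematically.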

* Remove unused `flax.linen` import
This commit is contained in:
Nicholas Vadivelu
2021-05-31 10:29:04 -04:00
committed by GitHub
parent fb60c309c6
commit 1ab147d648

View File

@@ -29,7 +29,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
import transformers import transformers
from flax import linen as nn
from flax import struct, traverse_util from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard from flax.metrics import tensorboard
@@ -202,7 +201,6 @@ def create_train_state(
else: # Classification. else: # Classification.
def cross_entropy_loss(logits, labels): def cross_entropy_loss(logits, labels):
logits = nn.log_softmax(logits)
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels)) xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
return jnp.mean(xentropy) return jnp.mean(xentropy)