Remove redundant nn.log_softmax in run_flax_glue.py (#11920)
* Remove redundant `nn.log_softmax` in `run_flax_glue.py` `optax.softmax_cross_entropy` expects unnormalized logits, and so it already calls `nn.log_softmax`, so I believe it is not needed here. `nn.log_softmax` is idempotent so mathematically it shouldn't have made a difference. * Remove unused 'flax.linen' import
This commit is contained in:
committed by
GitHub
parent
fb60c309c6
commit
1ab147d648
@@ -29,7 +29,6 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
import transformers
|
import transformers
|
||||||
from flax import linen as nn
|
|
||||||
from flax import struct, traverse_util
|
from flax import struct, traverse_util
|
||||||
from flax.jax_utils import replicate, unreplicate
|
from flax.jax_utils import replicate, unreplicate
|
||||||
from flax.metrics import tensorboard
|
from flax.metrics import tensorboard
|
||||||
@@ -202,7 +201,6 @@ def create_train_state(
|
|||||||
else: # Classification.
|
else: # Classification.
|
||||||
|
|
||||||
def cross_entropy_loss(logits, labels):
|
def cross_entropy_loss(logits, labels):
|
||||||
logits = nn.log_softmax(logits)
|
|
||||||
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
|
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
|
||||||
return jnp.mean(xentropy)
|
return jnp.mean(xentropy)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user