Pass decay_mask_fn to optimizer (#12087)
This commit is contained in:
@@ -38,7 +38,7 @@ import flax
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
from flax import jax_utils
|
from flax import jax_utils, traverse_util
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -504,6 +504,15 @@ if __name__ == "__main__":
|
|||||||
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We use Optax's "masking" functionality to not apply weight decay
|
||||||
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
|
# boolean mask with the same structure as the parameters.
|
||||||
|
# The mask is True for parameters that should be decayed.
|
||||||
|
def decay_mask_fn(params):
    """Build the weight-decay mask for a nested parameter dict.

    The nested dict is flattened so that every leaf is addressed by a
    tuple of keys, each leaf is mapped to a boolean, and the result is
    unflattened again so the mask mirrors the structure of ``params``.

    A leaf is marked ``True`` (weight decay applies) unless its path ends
    in ``"bias"`` or in ``("LayerNorm", "scale")``.
    """

    def _should_decay(key_path):
        # Bias terms and LayerNorm scales are excluded from weight decay.
        is_bias = key_path[-1] == "bias"
        is_layer_norm_scale = key_path[-2:] == ("LayerNorm", "scale")
        return not (is_bias or is_layer_norm_scale)

    mask_by_path = {
        key_path: _should_decay(key_path)
        for key_path in traverse_util.flatten_dict(params)
    }
    return traverse_util.unflatten_dict(mask_by_path)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
adamw = optax.adamw(
|
adamw = optax.adamw(
|
||||||
learning_rate=linear_decay_lr_schedule_fn,
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
@@ -511,6 +520,7 @@ if __name__ == "__main__":
|
|||||||
b2=training_args.adam_beta2,
|
b2=training_args.adam_beta2,
|
||||||
eps=1e-8,
|
eps=1e-8,
|
||||||
weight_decay=training_args.weight_decay,
|
weight_decay=training_args.weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup train state
|
# Setup train state
|
||||||
|
|||||||
Reference in New Issue
Block a user