From d1500d9151ae1b728aa3561f035d96f139d0f5ca Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 9 Jun 2021 23:19:27 +0530 Subject: [PATCH] pass decay_mask fn to optimizer (#12087) --- examples/flax/language-modeling/run_mlm_flax.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index dddd6ce478..ff38b0090e 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -38,7 +38,7 @@ import flax import jax import jax.numpy as jnp import optax -from flax import jax_utils +from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from transformers import ( @@ -504,6 +504,15 @@ if __name__ == "__main__": schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps] ) + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + # create adam optimizer adamw = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, @@ -511,6 +520,7 @@ if __name__ == "__main__": b2=training_args.adam_beta2, eps=1e-8, weight_decay=training_args.weight_decay, + mask=decay_mask_fn, ) # Setup train state