Generalize decay_mask_fn to apply mask to all LayerNorm params (#18273)

* generalize decay_mask_fn to find all layernorm params

* fixup

* generalising decay_mask_fn
This commit is contained in:
Duong A. Nguyen
2022-07-27 18:23:57 +07:00
committed by GitHub
parent 83d2d74509
commit 170fcaa604
8 changed files with 88 additions and 40 deletions

View File

@@ -327,12 +327,19 @@ def create_train_state(
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# boolean mask with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
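# Note that LayerNorm parameters are detected by name: any parameter whose
# path contains "layernorm", "layer_norm" or "ln" (case-insensitive) is
# excluded from weight decay. For models that name their layer norms
# differently, one should extend the list of candidate names accordingly.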
def decay_mask_fn(params):
    """Build the boolean weight-decay mask for ``params``.

    The parameter tree is flattened with ``traverse_util.flatten_dict``; each
    resulting key is used here as a sequence of name strings (it is joined,
    sliced and stored in a set). A parameter is marked ``True`` (decayed)
    unless its last path component is ``"bias"`` or its last two path
    components match those of a parameter identified as belonging to a
    LayerNorm.

    Args:
        params: nested parameter dictionary.

    Returns:
        The result of ``traverse_util.unflatten_dict`` applied to the flat
        mask, i.e. a tree with one boolean per parameter path.
    """
    flat_params = traverse_util.flatten_dict(params)

    # Substrings that identify a LayerNorm somewhere in a parameter path.
    layer_norm_candidates = ("layernorm", "layer_norm", "ln")

    # Collect the trailing two path components of every parameter whose
    # concatenated, lower-cased path contains one of the candidate names.
    # NOTE(review): the path is joined without a separator and matched by
    # substring, so "ln" can also hit unrelated names or span two adjacent
    # components; any parameter sharing a collected suffix is excluded from
    # decay as well. Behavior is intentionally left unchanged here.
    layer_norm_named_params = set()
    for path in flat_params:
        joined_path = "".join(path).lower()
        if any(name in joined_path for name in layer_norm_candidates):
            layer_norm_named_params.add(path[-2:])

    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params)
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(