Generalize decay_mask_fn to apply mask to all LayerNorm params (#18273)
* generalize decay_mask_fn to find all layernorm params
* fixup
* generalising decay_mask_fn
This commit is contained in:
@@ -875,15 +875,19 @@ def main():
|
||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||
# mask boolean with the same structure as the parameters.
|
||||
# The mask is True for parameters that should be decayed.
|
||||
# LayerNorm parameters are detected by name (any parameter path containing
|
||||
# "layernorm", "layer_norm" or "ln"), so the mask is no longer specific to
|
||||
# FlaxBart and should not need per-model adjustment (e.g. for FlaxT5).
|
||||
def decay_mask_fn(params):
    """Build the weight-decay mask for ``params``.

    The parameters are flattened, every entry is marked True (decay) or
    False (no decay), and the result is unflattened again so the mask has
    the same nested structure as ``params``.

    An entry is NOT decayed when either:
      * the last component of its path is ``"bias"``, or
      * the last two components of its path match those of any parameter
        whose path looks like a LayerNorm (see the candidates below).

    Args:
        params: nested parameter dict accepted by
            ``traverse_util.flatten_dict`` (presumably the Flax model
            params; keys are assumed to be strings - confirm at call site).

    Returns:
        A nested dict of booleans mirroring ``params``.
    """
    flat_params = traverse_util.flatten_dict(params)

    # Substrings that identify a LayerNorm module in a parameter path.
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]

    # Collect the trailing two path components of every parameter whose
    # path mentions a LayerNorm candidate. The match is done on the
    # lower-cased concatenation of all path components.
    # NOTE(review): the components are joined without a separator, so a
    # candidate (especially the short "ln") can also match across a component
    # boundary, and the lookup below is by trailing pair only - confirm this
    # is acceptable for the model in use.
    layer_norm_named_params = {
        path[-2:]
        for path in flat_params
        if any(candidate in "".join(path).lower() for candidate in layer_norm_candidates)
    }

    # True -> apply weight decay; False -> bias or LayerNorm parameter.
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
|
||||
|
||||
# create adam optimizer
|
||||
|
||||
Reference in New Issue
Block a user