From 813328682e2a6822ff8d0fde30a7ed9012449daf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Jun 2021 12:01:08 +0100 Subject: [PATCH] [Flax] Example scripts - correct weight decay (#12409) * fix_torch_device_generate_test * remove @ * finish * finish * correct style --- examples/flax/language-modeling/run_clm_flax.py | 8 +++++++- examples/flax/language-modeling/run_mlm_flax.py | 3 +++ examples/flax/language-modeling/run_t5_mlm_flax.py | 5 ++++- examples/flax/summarization/run_summarization_flax.py | 8 +++++++- 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index c313ad0b3a..e664e5718a 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -477,9 +477,15 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxGPT2. + # For other models, one should correct the layer norm parameter naming + # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + flat_mask = { + path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]) + for path in flat_params + } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 945cd4eb65..e3058c4ca7 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -508,6 +508,9 @@ if __name__ == "__main__": # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxBERT-like models. + # For other models, one should correct the layer norm parameter naming + # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index c79304ec2c..49f4cf1d79 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -626,7 +626,10 @@ if __name__ == "__main__": # The mask is True for parameters that should be decayed. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + flat_mask = { + path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) + for path in flat_params + } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 3abefc1d1e..636fa3bb85 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -578,9 +578,15 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + layer_norm_params = [ + (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer