diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index 149d3abff5..a4deab8041 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -875,15 +875,19 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxBart. - # For FlaxT5, one should correct the layer norm parameter naming - # accordingly - see `run_t5_mlm_flax.py` e.g. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - layer_norm_params = [ - (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] - ] - flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index ee40d80cab..00fc6e61f7 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -638,15 +638,19 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxGPT2. - # For other models, one should correct the layer norm parameter naming - # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = { - path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]) - for path in flat_params - } + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 0343d02341..9657471246 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -658,12 +658,19 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxBERT-like models. - # For other models, one should correct the layer norm parameter naming - # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index c9ab71c2b9..ad0b43d3d6 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -326,7 +326,6 @@ class FlaxDataCollatorForT5MLM: decoder_start_token_id: int def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: - # convert list to dict and tensorize input batch = BatchEncoding( {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()} @@ -395,7 +394,6 @@ class FlaxDataCollatorForT5MLM: return input_ids def random_spans_noise_mask(self, length): - """This function is copy of `random_spans_helper `__ . Noise mask consisting of random spans of noise tokens. @@ -782,10 +780,17 @@ def main(): # The mask is True for parameters that should be decayed. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = { - path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) - for path in flat_params - } + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 740ad93966..b424756355 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -327,12 +327,19 @@ def create_train_state( # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxBERT-like models. - # For other models, one should correct the layer norm parameter naming - # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) tx = optax.adamw( diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 089b3c1911..a1b5fc37e2 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -723,15 +723,19 @@ def main(): # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxBart. - # For FlaxT5, one should correct the layer norm parameter naming - # accordingly - see `run_t5_mlm_flax.py` e.g. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - layer_norm_params = [ - (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] - ] - flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) # create adam optimizer diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 48ce733a23..d777123185 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -226,7 +226,17 @@ def create_train_state( # The mask is True for parameters that should be decayed. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) tx = optax.adamw( diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 6b3a20ae27..682fc03b8b 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -284,12 +284,19 @@ def create_train_state( # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. - # Note that this mask is specifically adapted for FlaxBERT-like models. - # For other models, one should correct the layer norm parameter naming - # accordingly. def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) - flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = set( + [ + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + ] + ) + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} return traverse_util.unflatten_dict(flat_mask) tx = optax.adamw(