[Flax] Example scripts - correct weight decay (#12409)
* fix_torch_device_generate_test * remove @ * finish * finish * correct style
This commit is contained in:
committed by
GitHub
parent
aecae53377
commit
813328682e
@@ -477,9 +477,15 @@ def main():
|
|||||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
# mask boolean with the same structure as the parameters.
|
# mask boolean with the same structure as the parameters.
|
||||||
# The mask is True for parameters that should be decayed.
|
# The mask is True for parameters that should be decayed.
|
||||||
|
# Note that this mask is specifically adapted for FlaxGPT2.
|
||||||
|
# For other models, one should correct the layer norm parameter naming
|
||||||
|
# accordingly.
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
flat_mask = {
|
||||||
|
path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
|
||||||
|
for path in flat_params
|
||||||
|
}
|
||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
|
|||||||
@@ -508,6 +508,9 @@ if __name__ == "__main__":
|
|||||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
# mask boolean with the same structure as the parameters.
|
# mask boolean with the same structure as the parameters.
|
||||||
# The mask is True for parameters that should be decayed.
|
# The mask is True for parameters that should be decayed.
|
||||||
|
# Note that this mask is specifically adapted for FlaxBERT-like models.
|
||||||
|
# For other models, one should correct the layer norm parameter naming
|
||||||
|
# accordingly.
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
||||||
|
|||||||
@@ -626,7 +626,10 @@ if __name__ == "__main__":
|
|||||||
# The mask is True for parameters that should be decayed.
|
# The mask is True for parameters that should be decayed.
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
flat_mask = {
|
||||||
|
path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
|
||||||
|
for path in flat_params
|
||||||
|
}
|
||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
|
|||||||
@@ -578,9 +578,15 @@ def main():
|
|||||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
# mask boolean with the same structure as the parameters.
|
# mask boolean with the same structure as the parameters.
|
||||||
# The mask is True for parameters that should be decayed.
|
# The mask is True for parameters that should be decayed.
|
||||||
|
# Note that this mask is specifically adapted for FlaxBart.
|
||||||
|
# For FlaxT5, one should correct the layer norm parameter naming
|
||||||
|
# accordingly - see `run_t5_mlm_flax.py` e.g.
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
layer_norm_params = [
|
||||||
|
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
|
||||||
|
]
|
||||||
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
|
||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
|
|||||||
Reference in New Issue
Block a user