From 4674061b2aef7592529a831582b5135e7a23a183 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Thu, 3 Jun 2021 06:35:26 -0400 Subject: [PATCH] Fix weight decay masking in `run_flax_glue.py` (#11964) * Fix weight decay masking in `run_flax_glue.py` Issues with the previous implementation: - The `dict` from `traverse_util.flatten_dict` has keys which are tuples of strings, not one long string with the path separated by periods. - `optax.masked` applies the transformation wherever the mask is True, so the masks are flipped. - Flax's LayerNorm calls the scale parameter `scale` not `weight` * Fix formatting with black * adapt results Co-authored-by: Patrick von Platen --- examples/flax/text-classification/README.md | 18 ++++++------ .../flax/text-classification/run_flax_glue.py | 28 +++++++------------ 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index 50b4fd2f5d..c7dd12d3d2 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -63,15 +63,15 @@ In the Tensorboard results linked below, the random seed of each model is equal | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics | |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------| -| CoLA | Matthew's corr | 60.82 | 59.04 | 1.17 | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) | -| SST-2 | Accuracy | 92.43 | 92.13 | 0.38 | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) | -| MRPC | F1/Accuracy | 89.90/88.98 | 88.98/85.30 | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) | -| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) | -| QQP | Accuracy/F1 | 90.82/87.54 | 90.75/87.53 | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) | -| MNLI | Matched acc. | 84.10 | 83.84 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) | -| QNLI | Accuracy | 91.07 | 90.83 | 0.19 | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) | -| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) | -| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) | +| CoLA | Matthew's corr | 60.57 | 59.04 | 1.06 | [tfhub.dev](https://tensorboard.dev/experiment/lfr2adVpRtmLDALKrElkzg/) | +| SST-2 | Accuracy | 92.66 | 92.23 | 0.57 | [tfhub.dev](https://tensorboard.dev/experiment/jYvfv2trRHKMjoWnXVwrZA/) | +| MRPC | F1/Accuracy | 89.90/85.78 | 88.97/84.36 | 0.72/1.09 | [tfhub.dev](https://tensorboard.dev/experiment/bo3W3DEoRw2Q7YXjWrJkfg/) | +| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/fxVwbLD7QpKhbot0r9rn2w/) | +| QQP | Accuracy/F1 | 90.81/87.58 | 90.76/87.51 | 0.05/0.06 | [tfhub.dev](https://tensorboard.dev/experiment/di089Rc9TZmsnKRMrYNLsA/) | +| MNLI | Matched acc. | 84.10 | 83.80 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/JgNCGHDJSRaW6HBx6YQFYQ/) | +| QNLI | Accuracy | 91.01 | 90.82 | 0.17 | [tfhub.dev](https://tensorboard.dev/experiment/Bq7cMGJnQMSggYgL8qNGeQ/) | +| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/66Eq24bhRjqN6CEhgDSGqQ/) | +| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/TAqcnddqTkWvVEeGaWwIdQ/) | Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 899cdbd9b1..14862f7726 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -165,25 +165,17 @@ def create_train_state( logits_fn: Callable = struct.field(pytree_node=False) loss_fn: Callable = struct.field(pytree_node=False) - # Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers. - def adamw(decay): - return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay) + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) - def traverse(fn): - def mask(data): - flat = traverse_util.flatten_dict(data) - return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) - - return mask - - # We use Optax's "masking" functionality to create a multi-optimizer, one - # with weight decay and the other without. Note masking means the optimizer - # will ignore these paths. - decay_path = lambda p: not any(x in p for x in ["bias", "LayerNorm.weight"]) # noqa: E731 - - tx = optax.chain( - optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))), - optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))), + tx = optax.adamw( + learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn ) if is_regression: