Fix weight decay masking in run_flax_glue.py (#11964)

* Fix weight decay masking in `run_flax_glue.py`

Issues with the previous implementation:
- The `dict` from `traverse_util.flatten_dict` has keys which are tuples of strings, not one long string with the path separated by periods.
- `optax.masked` applies the transformation wherever the mask is True, so the masks are flipped.
- Flax's LayerNorm calls the scale parameter `scale` not `weight`

* Fix formatting with black

* adapt results

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Nicholas Vadivelu
2021-06-03 06:35:26 -04:00
committed by GitHub
parent 61c5063491
commit 4674061b2a
2 changed files with 19 additions and 27 deletions

View File

@@ -63,15 +63,15 @@ In the Tensorboard results linked below, the random seed of each model is equal
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics | | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------| |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
| CoLA | Matthew's corr | 60.82 | 59.04 | 1.17 | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) | | CoLA | Matthew's corr | 60.57 | 59.04 | 1.06 | [tfhub.dev](https://tensorboard.dev/experiment/lfr2adVpRtmLDALKrElkzg/) |
| SST-2 | Accuracy | 92.43 | 92.13 | 0.38 | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) | | SST-2 | Accuracy | 92.66 | 92.23 | 0.57 | [tfhub.dev](https://tensorboard.dev/experiment/jYvfv2trRHKMjoWnXVwrZA/) |
| MRPC | F1/Accuracy | 89.90/88.98 | 88.98/85.30 | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) | | MRPC | F1/Accuracy | 89.90/85.78 | 88.97/84.36 | 0.72/1.09 | [tfhub.dev](https://tensorboard.dev/experiment/bo3W3DEoRw2Q7YXjWrJkfg/) |
| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) | | STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/fxVwbLD7QpKhbot0r9rn2w/) |
| QQP | Accuracy/F1 | 90.82/87.54 | 90.75/87.53 | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) | | QQP | Accuracy/F1 | 90.81/87.58 | 90.76/87.51 | 0.05/0.06 | [tfhub.dev](https://tensorboard.dev/experiment/di089Rc9TZmsnKRMrYNLsA/) |
| MNLI | Matched acc. | 84.10 | 83.84 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) | | MNLI | Matched acc. | 84.10 | 83.80 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/JgNCGHDJSRaW6HBx6YQFYQ/) |
| QNLI | Accuracy | 91.07 | 90.83 | 0.19 | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) | | QNLI | Accuracy | 91.01 | 90.82 | 0.17 | [tfhub.dev](https://tensorboard.dev/experiment/Bq7cMGJnQMSggYgL8qNGeQ/) |
| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) | | RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/66Eq24bhRjqN6CEhgDSGqQ/) |
| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) | | WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/TAqcnddqTkWvVEeGaWwIdQ/) |
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.

View File

@@ -165,25 +165,17 @@ def create_train_state(
logits_fn: Callable = struct.field(pytree_node=False) logits_fn: Callable = struct.field(pytree_node=False)
loss_fn: Callable = struct.field(pytree_node=False) loss_fn: Callable = struct.field(pytree_node=False)
# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers. # We use Optax's "masking" functionality to not apply weight decay
def adamw(decay): # to bias and LayerNorm scale parameters. decay_mask_fn returns a
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay) # mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
def traverse(fn): tx = optax.adamw(
def mask(data): learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
flat = traverse_util.flatten_dict(data)
return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
return mask
# We use Optax's "masking" functionality to create a multi-optimizer, one
# with weight decay and the other without. Note masking means the optimizer
# will ignore these paths.
decay_path = lambda p: not any(x in p for x in ["bias", "LayerNorm.weight"]) # noqa: E731
tx = optax.chain(
optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
) )
if is_regression: if is_regression: