[Flax] Align GLUE training script with mlm training script (#11778)
* speed up flax glue * remove unnecessary line * remove folder * remove run in loop Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
223943872e
commit
bd9871657b
@@ -59,20 +59,19 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is
|
||||
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
|
||||
are undertrained and we could get better results when training longer.
|
||||
|
||||
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
|
||||
|
||||
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing).
|
||||
|
||||
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|
||||
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
|
||||
| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
|
||||
| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
|
||||
| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
|
||||
| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
|
||||
| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
|
||||
| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
|
||||
| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
|
||||
| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
|
||||
| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
|
||||
| CoLA | Matthew's corr | 60.82 | 59.04 | 1.17 | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) |
|
||||
| SST-2 | Accuracy | 92.43 | 92.13 | 0.38 | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) |
|
||||
| MRPC | F1/Accuracy | 89.90/88.98 | 88.98/85.30 | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) |
|
||||
| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) |
|
||||
| QQP | Accuracy/F1 | 90.82/87.54 | 90.75/87.53 | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) |
|
||||
| MNLI | Matched acc. | 84.10 | 83.84 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) |
|
||||
| QNLI | Accuracy | 91.07 | 90.83 | 0.19 | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) |
|
||||
| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) |
|
||||
| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) |
|
||||
|
||||
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
||||
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
||||
@@ -85,18 +84,18 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
|
||||
|
||||
| Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) |
|
||||
|-------|-----------|------------|------------|-----------------|
|
||||
| CoLA | 1m 46s | 1m 26s | 3m 9s | 4m 6s |
|
||||
| SST-2 | 5m 30s | 6m 28s | 22m 33s | 34m 37s |
|
||||
| MRPC | 1m 32s | 1m 14s | 2m 20s | 2m 56s |
|
||||
| STS-B | 1m 33s | 1m 12s | 2m 16s | 2m 48s |
|
||||
| QQP | 24m 40s | 31m 48s | 1h 59m 41s | 2h 54m |
|
||||
| MNLI | 26m 30s | 33m 55s | 2h 9m 37s | 3h 7m 6s |
|
||||
| QNLI | 8m | 9m 40s | 34m 40s | 49m 8s |
|
||||
| RTE | 1m 21s | 55s | 1m 10s | 1m 16s |
|
||||
| WNLI | 1m 12s | 48s | 39s | 36s |
|
||||
| CoLA | 1m 42s | 1m 26s | 3m 9s | 4m 6s |
|
||||
| SST-2 | 5m 12s | 6m 28s | 22m 33s | 34m 37s |
|
||||
| MRPC | 1m 29s | 1m 14s | 2m 20s | 2m 56s |
|
||||
| STS-B | 1m 30s | 1m 12s | 2m 16s | 2m 48s |
|
||||
| QQP | 22m 50s | 31m 48s | 1h 59m 41s | 2h 54m |
|
||||
| MNLI | 25m 03s | 33m 55s | 2h 9m 37s | 3h 7m 6s |
|
||||
| QNLI | 7m30s | 9m 40s | 34m 40s | 49m 8s |
|
||||
| RTE | 1m 20s | 55s | 1m 10s | 1m 16s |
|
||||
| WNLI | 1m 11s | 48s | 39s | 36s |
|
||||
|-------|
|
||||
| **TOTAL** | 1h 13m | 1h 28m | 5h 16m | 6h 37m |
|
||||
| **COST*** | $9.60 | $29.10 | $13.06 | $16.41 |
|
||||
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
|
||||
| **COST*** | $8.56 | $29.10 | $13.06 | $16.41 |
|
||||
|
||||
|
||||
*All experiments are ran on Google Cloud Platform. Prices are on-demand prices
|
||||
@@ -106,4 +105,4 @@ the following tables:
|
||||
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
|
||||
V100 GPU). GPU experiments are ran without further optimizations besides JAX
|
||||
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
|
||||
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
||||
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
||||
|
||||
Reference in New Issue
Block a user