[Flax] Align GLUE training script with mlm training script (#11778)
* speed up flax glue * remove unnecessary line * remove folder * remove run in loop Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
223943872e
commit
bd9871657b
@@ -59,20 +59,19 @@ On the tasks other than MRPC and WNLI we train for 3 epochs because this is
|
|||||||
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
|
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
|
||||||
are undertrained and we could get better results when training longer.
|
are undertrained and we could get better results when training longer.
|
||||||
|
|
||||||
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
|
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing).
|
||||||
|
|
||||||
|
|
||||||
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|
||||||
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
|
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
|
||||||
| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
|
| CoLA | Matthew's corr | 60.82 | 59.04 | 1.17 | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) |
|
||||||
| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
|
| SST-2 | Accuracy | 92.43 | 92.13 | 0.38 | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) |
|
||||||
| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
|
| MRPC | F1/Accuracy | 89.90/88.98 | 88.98/85.30 | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) |
|
||||||
| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
|
| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) |
|
||||||
| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
|
| QQP | Accuracy/F1 | 90.82/87.54 | 90.75/87.53 | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) |
|
||||||
| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
|
| MNLI | Matched acc. | 84.10 | 83.84 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) |
|
||||||
| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
|
| QNLI | Accuracy | 91.07 | 90.83 | 0.19 | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) |
|
||||||
| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
|
| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) |
|
||||||
| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
|
| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) |
|
||||||
|
|
||||||
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
||||||
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
||||||
@@ -85,18 +84,18 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
|
|||||||
|
|
||||||
| Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) |
|
| Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) |
|
||||||
|-------|-----------|------------|------------|-----------------|
|
|-------|-----------|------------|------------|-----------------|
|
||||||
| CoLA | 1m 46s | 1m 26s | 3m 9s | 4m 6s |
|
| CoLA | 1m 42s | 1m 26s | 3m 9s | 4m 6s |
|
||||||
| SST-2 | 5m 30s | 6m 28s | 22m 33s | 34m 37s |
|
| SST-2 | 5m 12s | 6m 28s | 22m 33s | 34m 37s |
|
||||||
| MRPC | 1m 32s | 1m 14s | 2m 20s | 2m 56s |
|
| MRPC | 1m 29s | 1m 14s | 2m 20s | 2m 56s |
|
||||||
| STS-B | 1m 33s | 1m 12s | 2m 16s | 2m 48s |
|
| STS-B | 1m 30s | 1m 12s | 2m 16s | 2m 48s |
|
||||||
| QQP | 24m 40s | 31m 48s | 1h 59m 41s | 2h 54m |
|
| QQP | 22m 50s | 31m 48s | 1h 59m 41s | 2h 54m |
|
||||||
| MNLI | 26m 30s | 33m 55s | 2h 9m 37s | 3h 7m 6s |
|
| MNLI | 25m 03s | 33m 55s | 2h 9m 37s | 3h 7m 6s |
|
||||||
| QNLI | 8m | 9m 40s | 34m 40s | 49m 8s |
|
| QNLI | 7m 30s | 9m 40s | 34m 40s | 49m 8s |
|
||||||
| RTE | 1m 21s | 55s | 1m 10s | 1m 16s |
|
| RTE | 1m 20s | 55s | 1m 10s | 1m 16s |
|
||||||
| WNLI | 1m 12s | 48s | 39s | 36s |
|
| WNLI | 1m 11s | 48s | 39s | 36s |
|
||||||
|-------|
|
|-------|
|
||||||
| **TOTAL** | 1h 13m | 1h 28m | 5h 16m | 6h 37m |
|
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
|
||||||
| **COST*** | $9.60 | $29.10 | $13.06 | $16.41 |
|
| **COST*** | $8.56 | $29.10 | $13.06 | $16.41 |
|
||||||
|
|
||||||
|
|
||||||
*All experiments are run on Google Cloud Platform. Prices are on-demand prices
|
*All experiments are run on Google Cloud Platform. Prices are on-demand prices
|
||||||
@@ -106,4 +105,4 @@ the following tables:
|
|||||||
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
|
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
|
||||||
V100 GPU). GPU experiments are run without further optimizations besides JAX
|
V100 GPU). GPU experiments are run without further optimizations besides JAX
|
||||||
transformations. GPU experiments are run with full precision (fp32). "TPU v3-8"
|
transformations. GPU experiments are run with full precision (fp32). "TPU v3-8"
|
||||||
are 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" are 8 GPU chips.
|
are 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" are 8 GPU chips.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from flax import struct, traverse_util
|
|||||||
from flax.jax_utils import replicate, unreplicate
|
from flax.jax_utils import replicate, unreplicate
|
||||||
from flax.metrics import tensorboard
|
from flax.metrics import tensorboard
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
|
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -407,6 +407,7 @@ def main():
|
|||||||
|
|
||||||
num_epochs = int(args.num_train_epochs)
|
num_epochs = int(args.num_train_epochs)
|
||||||
rng = jax.random.PRNGKey(args.seed)
|
rng = jax.random.PRNGKey(args.seed)
|
||||||
|
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||||
|
|
||||||
train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
|
train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
|
||||||
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
|
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
|
||||||
@@ -424,6 +425,7 @@ def main():
|
|||||||
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
|
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
|
||||||
) -> Tuple[train_state.TrainState, float]:
|
) -> Tuple[train_state.TrainState, float]:
|
||||||
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
|
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
|
||||||
|
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||||
targets = batch.pop("labels")
|
targets = batch.pop("labels")
|
||||||
|
|
||||||
def loss_fn(params):
|
def loss_fn(params):
|
||||||
@@ -436,7 +438,7 @@ def main():
|
|||||||
grad = jax.lax.pmean(grad, "batch")
|
grad = jax.lax.pmean(grad, "batch")
|
||||||
new_state = state.apply_gradients(grads=grad)
|
new_state = state.apply_gradients(grads=grad)
|
||||||
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
|
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
|
||||||
return new_state, metrics
|
return new_state, metrics, new_dropout_rng
|
||||||
|
|
||||||
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
|
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
|
||||||
|
|
||||||
@@ -467,9 +469,7 @@ def main():
|
|||||||
|
|
||||||
# train
|
# train
|
||||||
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
|
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
|
||||||
rng, dropout_rng = jax.random.split(rng)
|
state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
|
||||||
dropout_rngs = shard_prng_key(dropout_rng)
|
|
||||||
state, metrics = p_train_step(state, batch, dropout_rngs)
|
|
||||||
train_metrics.append(metrics)
|
train_metrics.append(metrics)
|
||||||
train_time += time.time() - train_start
|
train_time += time.time() - train_start
|
||||||
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user