From bd9871657bb9500a9f4437a873db6df5f1ae6dbb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 21 May 2021 09:36:56 +0100 Subject: [PATCH] [Flax] Align GLUE training script with mlm training script (#11778) * speed up flax glue * remove unnecessary line * remove folder * remove run in loop Co-authored-by: Patrick von Platen --- examples/flax/text-classification/README.md | 45 +++++++++---------- .../flax/text-classification/run_flax_glue.py | 10 ++--- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index 79eb4e00de..9bcced8365 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -59,20 +59,19 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models are undertrained and we could get better results when training longer. -In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing). - +In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing). | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics | |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------| -| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) | -| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) | -| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) | -| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) | -| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) | -| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) | -| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) | -| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) | -| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) | +| CoLA | Matthew's corr | 60.82 | 59.04 | 1.17 | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/) | +| SST-2 | Accuracy | 92.43 | 92.13 | 0.38 | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/) | +| MRPC | F1/Accuracy | 89.90/88.98 | 88.98/85.30 | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/) | +| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/) | +| QQP | Accuracy/F1 | 90.82/87.54 | 90.75/87.53 | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/) | +| MNLI | Matched acc. | 84.10 | 83.84 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/) | +| QNLI | Accuracy | 91.07 | 90.83 | 0.19 | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/) | +| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/) | +| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/) | Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. @@ -85,18 +84,18 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https | Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) | |-------|-----------|------------|------------|-----------------| -| CoLA | 1m 46s | 1m 26s | 3m 9s | 4m 6s | -| SST-2 | 5m 30s | 6m 28s | 22m 33s | 34m 37s | -| MRPC | 1m 32s | 1m 14s | 2m 20s | 2m 56s | -| STS-B | 1m 33s | 1m 12s | 2m 16s | 2m 48s | -| QQP | 24m 40s | 31m 48s | 1h 59m 41s | 2h 54m | -| MNLI | 26m 30s | 33m 55s | 2h 9m 37s | 3h 7m 6s | -| QNLI | 8m | 9m 40s | 34m 40s | 49m 8s | -| RTE | 1m 21s | 55s | 1m 10s | 1m 16s | -| WNLI | 1m 12s | 48s | 39s | 36s | +| CoLA | 1m 42s | 1m 26s | 3m 9s | 4m 6s | +| SST-2 | 5m 12s | 6m 28s | 22m 33s | 34m 37s | +| MRPC | 1m 29s | 1m 14s | 2m 20s | 2m 56s | +| STS-B | 1m 30s | 1m 12s | 2m 16s | 2m 48s | +| QQP | 22m 50s | 31m 48s | 1h 59m 41s | 2h 54m | +| MNLI | 25m 03s | 33m 55s | 2h 9m 37s | 3h 7m 6s | +| QNLI | 7m30s | 9m 40s | 34m 40s | 49m 8s | +| RTE | 1m 20s | 55s | 1m 10s | 1m 16s | +| WNLI | 1m 11s | 48s | 39s | 36s | |-------| -| **TOTAL** | 1h 13m | 1h 28m | 5h 16m | 6h 37m | -| **COST*** | $9.60 | $29.10 | $13.06 | $16.41 | +| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m | +| **COST*** | $8.56 | $29.10 | $13.06 | $16.41 | *All experiments are ran on Google Cloud Platform. Prices are on-demand prices @@ -106,4 +105,4 @@ the following tables: [GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per V100 GPU). GPU experiments are ran without further optimizations besides JAX transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8" -are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips. \ No newline at end of file +are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips. diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index bf5bb0acac..0a0722863d 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -34,7 +34,7 @@ from flax import struct, traverse_util from flax.jax_utils import replicate, unreplicate from flax.metrics import tensorboard from flax.training import train_state -from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.training.common_utils import get_metrics, onehot, shard from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig @@ -407,6 +407,7 @@ def main(): num_epochs = int(args.num_train_epochs) rng = jax.random.PRNGKey(args.seed) + dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = args.per_device_train_batch_size * jax.local_device_count() eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count() @@ -424,6 +425,7 @@ def main(): state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): @@ -436,7 +438,7 @@ def main(): grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") - return new_state, metrics + return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) @@ -467,9 +469,7 @@ def main(): # train for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size): - rng, dropout_rng = jax.random.split(rng) - dropout_rngs = shard_prng_key(dropout_rng) - state, metrics = p_train_step(state, batch, dropout_rngs) + state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(metrics) train_time += time.time() - train_start logger.info(f" Done! Training metrics: {unreplicate(metrics)}")