From 6797cdc077782a8d9b94b57620b8b6a832791e80 Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Wed, 12 May 2021 14:52:52 +0200 Subject: [PATCH] Updates README and fixes bug (#11701) --- examples/flax/text-classification/README.md | 32 ++++++++++++------- .../flax/text-classification/run_flax_glue.py | 4 +-- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index cdb0c905c7..2826735101 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -83,14 +83,24 @@ We also ran each task once on a single V100 GPU, 8 V100 GPUs, and 8 Cloud v3 TPU overall training time below. For comparison we ran Pytorch's [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py) on a single GPU (last column). -| Task | 8 TPU | 8 GPU | 1 GPU | 1 GPU (Pytorch) | -|-------|---------|---------|------------|-----------------| -| CoLA | 1m 46s | 1m 26s | 3m 6s | 4m 6s | -| SST-2 | 5m 30s | 6m 28s | 22m 6s | 34m 37s | -| MRPC | 1m 32s | 1m 14s | 2m 17s | 2m 56s | -| STS-B | 1m 33s | 1m 12s | 2m 11s | 2m 48s | -| QQP | 24m 40s | 31m 48s | 1h 20m 15s | 2h 54m | -| MNLI | 26m 30s | 33m 55s | 2h 7m 30s | 3u 7m 6s | -| QNLI | 8m | 9m 40s | 34m 20s | 49m 8s | -| RTE | 1m 21s | 55s | 1m 8s | 1m 16s | -| WNLI | 1m 12s | 48s | 38s | 36s | +| Task | TPU v3-8 | 8 GPU | 1 GPU | 1 GPU (Pytorch) | +|-------|-----------|------------|------------|-----------------| +| CoLA | 1m 46s | 1m 26s | 3m 6s | 4m 6s | +| SST-2 | 5m 30s | 6m 28s | 22m 6s | 34m 37s | +| MRPC | 1m 32s | 1m 14s | 2m 17s | 2m 56s | +| STS-B | 1m 33s | 1m 12s | 2m 11s | 2m 48s | +| QQP | 24m 40s | 31m 48s | 1h 20m 15s | 2h 54m | +| MNLI | 26m 30s | 33m 55s | 2h 7m 30s | 3h 7m 6s | +| QNLI | 8m | 9m 40s | 34m 20s | 49m 8s | +| RTE | 1m 21s | 55s | 1m 8s | 1m 16s | +| WNLI | 1m 12s | 48s | 38s | 36s | +|-------| +| **TOTAL** | 1h 13m | 1h 28m | 4h 34m | 6h 37m | +| **COST*** | $9.60 | $29.10 | $11.33 | $16.41 | + + +*All experiments are ran on Google Cloud Platform. Prices are on-demand prices +(not preemptible), obtained from the following tables: +[TPU pricing table](https://cloud.google.com/tpu/pricing), +[GPU pricing table](https://cloud.google.com/compute/gpus-pricing). GPU +experiments are ran without further optimizations besides JAX transformations. \ No newline at end of file diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 217b7bdc38..f405dd9fc7 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -473,8 +473,8 @@ def main(): dropout_rngs = shard_prng_key(dropout_rng) state, metrics = p_train_step(state, batch, dropout_rngs) train_metrics.append(metrics) - train_time += time.time() - train_start - logger.info(f" Done! Training metrics: {unreplicate(metrics)}") + train_time += time.time() - train_start + logger.info(f" Done! Training metrics: {unreplicate(metrics)}") logger.info(" Evaluating...") rng, input_rng = jax.random.split(rng)