Updates README and fixes bug (#11701)
This commit is contained in:
@@ -83,14 +83,24 @@ We also ran each task once on a single V100 GPU, 8 V100 GPUs, and 8 Cloud v3 TPU
|
|||||||
overall training time below. For comparison we ran PyTorch's [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py) on a single GPU (last column).
|
overall training time below. For comparison we ran PyTorch's [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py) on a single GPU (last column).
|
||||||
|
|
||||||
|
|
||||||
| Task | 8 TPU | 8 GPU | 1 GPU | 1 GPU (Pytorch) |
|
| Task | TPU v3-8 | 8 GPU | 1 GPU | 1 GPU (Pytorch) |
|
||||||
|-------|---------|---------|------------|-----------------|
|
|-------|-----------|------------|------------|-----------------|
|
||||||
| CoLA | 1m 46s | 1m 26s | 3m 6s | 4m 6s |
|
| CoLA | 1m 46s | 1m 26s | 3m 6s | 4m 6s |
|
||||||
| SST-2 | 5m 30s | 6m 28s | 22m 6s | 34m 37s |
|
| SST-2 | 5m 30s | 6m 28s | 22m 6s | 34m 37s |
|
||||||
| MRPC | 1m 32s | 1m 14s | 2m 17s | 2m 56s |
|
| MRPC | 1m 32s | 1m 14s | 2m 17s | 2m 56s |
|
||||||
| STS-B | 1m 33s | 1m 12s | 2m 11s | 2m 48s |
|
| STS-B | 1m 33s | 1m 12s | 2m 11s | 2m 48s |
|
||||||
| QQP | 24m 40s | 31m 48s | 1h 20m 15s | 2h 54m |
|
| QQP | 24m 40s | 31m 48s | 1h 20m 15s | 2h 54m |
|
||||||
| MNLI | 26m 30s | 33m 55s | 2h 7m 30s | 3u 7m 6s |
|
| MNLI | 26m 30s | 33m 55s | 2h 7m 30s | 3h 7m 6s |
|
||||||
| QNLI | 8m | 9m 40s | 34m 20s | 49m 8s |
|
| QNLI | 8m | 9m 40s | 34m 20s | 49m 8s |
|
||||||
| RTE | 1m 21s | 55s | 1m 8s | 1m 16s |
|
| RTE | 1m 21s | 55s | 1m 8s | 1m 16s |
|
||||||
| WNLI | 1m 12s | 48s | 38s | 36s |
|
| WNLI | 1m 12s | 48s | 38s | 36s |
|
||||||
|
|-------|
|
||||||
|
| **TOTAL** | 1h 13m | 1h 28m | 4h 34m | 6h 37m |
|
||||||
|
| **COST*** | $9.60 | $29.10 | $11.33 | $16.41 |
|
||||||
|
|
||||||
|
|
||||||
|
*All experiments were run on Google Cloud Platform. Prices are on-demand prices
|
||||||
|
(not preemptible), obtained from the following tables:
|
||||||
|
[TPU pricing table](https://cloud.google.com/tpu/pricing),
|
||||||
|
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing). GPU
|
||||||
|
experiments were run without further optimizations besides JAX transformations.
|
||||||
@@ -473,8 +473,8 @@ def main():
|
|||||||
dropout_rngs = shard_prng_key(dropout_rng)
|
dropout_rngs = shard_prng_key(dropout_rng)
|
||||||
state, metrics = p_train_step(state, batch, dropout_rngs)
|
state, metrics = p_train_step(state, batch, dropout_rngs)
|
||||||
train_metrics.append(metrics)
|
train_metrics.append(metrics)
|
||||||
train_time += time.time() - train_start
|
train_time += time.time() - train_start
|
||||||
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
||||||
|
|
||||||
logger.info(" Evaluating...")
|
logger.info(" Evaluating...")
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
|||||||
Reference in New Issue
Block a user