[Flax] Adapt flax examples to include push_to_hub (#12391)
* fix_torch_device_generate_test
* remove @
* finish
* correct summary writer
* correct push to hub
* fix indent
* finish
* finish
* finish
* finish
* finish

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
committed by GitHub
parent a7d0b288fa
commit 2d70c91206
@@ -23,31 +23,68 @@ Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transfor
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).

GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
To begin with, it is recommended to create a model repository to save the trained model and logs.
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.

You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:

```
huggingface-cli repo create bert-glue-mrpc-test
```

Next we clone the model repository to add the tokenizer and model files.

```
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
```

To ensure that all TensorBoard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.

```
cd bert-glue-mrpc-test
git lfs track "*tfevents*"
```

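To see which files the `*tfevents*` pattern is meant to catch, here is a minimal Python sketch. It uses `fnmatch` as a stand-in for Git's own pattern matching (the two are not identical in general, but they agree for a simple glob like this one). The file names are hypothetical examples, not output from an actual run.

```python
from fnmatch import fnmatch

# Pattern passed to `git lfs track` in the step above.
PATTERN = "*tfevents*"

# Hypothetical file names, for illustration only.
candidates = [
    "events.out.tfevents.1623850000.my-host",  # TensorBoard-style event file
    "flax_model.msgpack",
    "config.json",
]


def lfs_tracked(names, pattern=PATTERN):
    """Return the names that the glob pattern matches."""
    return [name for name in names if fnmatch(name, pattern)]


tracked = lfs_tracked(candidates)
print(tracked)
```

Only the event file matches, so under these assumptions it is the only file of the three that this particular rule would route through Git LFS.
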
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to `run_flax_glue.py`.

```bash
export TASK_NAME=mrpc
export MODEL_DIR="./bert-glue-mrpc-test"
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
```

GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:

```bash
python run_flax_glue.py \
    --model_name_or_path bert-base-cased \
    --task_name $TASK_NAME \
    --task_name ${TASK_NAME} \
    --max_length 128 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --output_dir /tmp/$TASK_NAME/
    --output_dir ${MODEL_DIR} \
    --push_to_hub
```

where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.

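As a rough sketch of what `--per_device_train_batch_size 4` and `--num_train_epochs 3` imply for the length of a run: assuming the script multiplies the per-device batch size by the number of local devices and drops the last incomplete batch of each epoch, the step count can be estimated as below. The device count (8, as on one TPU v3-8) and the MRPC training-set size (3,668 examples) are assumptions that are not stated in this README.

```python
def training_steps(num_examples, per_device_batch_size, num_devices, num_epochs):
    """Estimate total batch size and step counts for a data-parallel run.

    Assumes the last incomplete batch of each epoch is dropped.
    """
    total_batch_size = per_device_batch_size * num_devices
    steps_per_epoch = num_examples // total_batch_size
    return total_batch_size, steps_per_epoch, steps_per_epoch * num_epochs


# Assumed values: 3,668 MRPC training examples, 8 devices.
batch, per_epoch, total = training_steps(
    num_examples=3668, per_device_batch_size=4, num_devices=8, num_epochs=3
)
print(batch, per_epoch, total)
```

On a single device the same flags would give a total batch size of 4 and many more steps per epoch, which is worth keeping in mind when comparing runs across hardware.
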
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in TensorFlow event files in `---output_dir`.
Metrics and hyperparameters are stored in TensorFlow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:

```bash
$ tensorboard --logdir .
```

or directly on the hub under *Training metrics*.

### Accuracy Evaluation

We train five replicas and report mean accuracy and stdev on the dev set below.
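
The aggregation described here can be sketched with Python's `statistics` module. The five accuracy values are made up for illustration, and the README does not say whether the reported stdev is the sample or the population form; this sketch assumes the sample form (`statistics.stdev`).

```python
import statistics

# Hypothetical dev-set accuracies (in %) from five replicas; not real results.
accuracies = [85.0, 86.0, 84.0, 85.5, 84.5]

mean_acc = statistics.mean(accuracies)
std_acc = statistics.stdev(accuracies)  # sample standard deviation (n - 1)

print(f"{mean_acc:.2f} +/- {std_acc:.2f}")
```
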
@@ -95,14 +132,8 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
| WNLI | 1m 11s | 48s | 39s | 36s |
|-------|
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
| **COST*** | $8.56 | $29.10 | $13.06 | $16.41 |

*All experiments were run on Google Cloud Platform. Prices are on-demand prices
(not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using
the following tables:
[TPU pricing table](https://cloud.google.com/tpu/pricing) ($8.00/h for v3-8),
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
V100 GPU). GPU experiments were run without further optimizations besides JAX
*All experiments were run on Google Cloud Platform.
GPU experiments were run without further optimizations besides JAX
transformations. GPU experiments were run with full precision (fp32). "TPU v3-8"
are 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" are 8 GPU chips.
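
The COST row can be approximately reproduced from the TOTAL row and the hourly prices quoted in the footnote. The sketch below assumes that the four columns correspond, in order, to one TPU v3-8, eight V100 GPUs, and two single-V100 runs; that mapping is inferred from the numbers and is not shown in this hunk. The dictionary keys are hypothetical labels.

```python
def run_cost(hours, minutes, hourly_rate, num_devices=1):
    """On-demand cost in USD of a run with the given wall-clock length."""
    return (hours + minutes / 60) * hourly_rate * num_devices


V100_RATE = 2.48      # USD per hour per V100 GPU, from the footnote
TPU_V3_8_RATE = 8.00  # USD per hour for a whole v3-8, from the footnote

# Wall-clock totals are taken from the TOTAL row; the device mapping is assumed.
costs = {
    "tpu_v3_8": run_cost(1, 3, TPU_V3_8_RATE),
    "8_gpu": run_cost(1, 28, V100_RATE, num_devices=8),
    "1_gpu_a": run_cost(5, 16, V100_RATE),
    "1_gpu_b": run_cost(6, 37, V100_RATE),
}

for name, usd in costs.items():
    print(f"{name}: ${usd:.2f}")
```

Under these assumptions the three GPU figures match the table to the cent. The TPU figure comes out at $8.40 rather than $8.56, which may reflect the total time having been rounded to the minute before it was reported.
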