[FLAX] glue training example refactor (#13815)

* refactor run_flax_glue.py

* updated readme

* rm unused import and args typo fix

* refactor

* make consistent arg name across task

* has_tensorboard check

* argparse -> argument dataclasses

* refactor according to review

* fix
Kamal Raj
2022-01-19 16:34:51 +05:30
committed by GitHub
parent db3503949d
commit d1f5ca1afd
3 changed files with 238 additions and 164 deletions


@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
python run_flax_glue.py \
--model_name_or_path bert-base-cased \
--task_name ${TASK_NAME} \
-    --max_length 128 \
+    --max_seq_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
+    --eval_steps 100 \
--output_dir ./$TASK_NAME/ \
--push_to_hub
```
-where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
+where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
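
For reviewers, the renamed flag and the updated task list can be illustrated with a small argument dataclass. This is only a sketch of the "argparse -> argument dataclasses" idea under stated assumptions: the class name `GlueArgs`, the `GLUE_TASKS` tuple, and the validation logic are hypothetical and are not taken from `run_flax_glue.py`. Only the task names and the `max_seq_length` value of 128 come from the README text above.

```python
from dataclasses import dataclass

# Task names copied from the updated README line.
# The tuple itself is an illustrative helper, not part of the script.
GLUE_TASKS = (
    "cola", "mnli", "mnli_mismatched", "mnli_matched", "mrpc",
    "qnli", "qqp", "rte", "sst2", "stsb", "wnli",
)


@dataclass
class GlueArgs:
    """Hypothetical argument container; the real script's dataclasses may differ."""

    task_name: str
    max_seq_length: int = 128  # the README example passes --max_seq_length 128

    def __post_init__(self):
        # Normalize case so "MRPC" and "mrpc" are treated the same.
        self.task_name = self.task_name.lower()
        if self.task_name not in GLUE_TASKS:
            raise ValueError(
                f"Unknown task {self.task_name!r}; expected one of: {', '.join(GLUE_TASKS)}"
            )
        if self.max_seq_length <= 0:
            raise ValueError("max_seq_length must be a positive integer")


args = GlueArgs(task_name="MRPC")
print(args.task_name, args.max_seq_length)
```

Under these assumptions, the old name `mnli-mm` would be rejected with a `ValueError`, which mirrors the rename to `mnli_mismatched` in the README.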