[FLAX] glue training example refactor (#13815)
* refactor run_flax_glue.py * updated readme * rm unused import and args typo fix * refactor * make consistent arg name across task * has_tensorboard check * argparse -> argument dataclasses * refactor according to review * fix
This commit is contained in:
@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
|
||||
python run_flax_glue.py \
|
||||
--model_name_or_path bert-base-cased \
|
||||
--task_name ${TASK_NAME} \
|
||||
--max_length 128 \
|
||||
--max_seq_length 128 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--eval_steps 100 \
|
||||
--output_dir ./$TASK_NAME/ \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
||||
where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
||||
|
||||
Using the command above, the script will train for 3 epochs and run eval after each epoch.
|
||||
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
|
||||
|
||||
Reference in New Issue
Block a user