From 82335185febc7bd27294f7fd0024c103d5dd502a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 21 May 2021 16:52:23 +0100 Subject: [PATCH] [Flax] Small fixes in `run_flax_glue.py` (#11820) * fix_torch_device_generate_test * remove @ * correct best seed for flax fine-tuning Co-authored-by: Patrick von Platen --- examples/flax/text-classification/README.md | 2 +- examples/flax/text-classification/run_flax_glue.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index 9bcced8365..50b4fd2f5d 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -59,7 +59,7 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models are undertrained and we could get better results when training longer. -In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing). +In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 3, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing). | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics | |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------| diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 0a0722863d..24aac7defd 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -123,7 +123,7 @@ def parse_args(): "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.") + parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.") args = parser.parse_args() # Sanity checks