From 726e953d44def8a51d9a4b183bd81fc0c506091e Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Mon, 17 May 2021 10:26:33 +0200 Subject: [PATCH] Improvements to Flax finetuning script (#11727) * Add Cloud details to README * Flax script and readme updates * Some simplifications of Flax script --- examples/flax/text-classification/README.md | 20 ++++++------ .../flax/text-classification/run_flax_glue.py | 31 +++++++++---------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index 14c4603e5a..79eb4e00de 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -59,20 +59,20 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models are undertrained and we could get better results when training longer. -In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing). +In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing). | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics | |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------| -| CoLA | Matthew's corr | 59.57 | 58.04 | 1.81 | [tfhub.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) | -| SST-2 | Accuracy | 92.43 | 91.79 | 0.59 | [tfhub.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) | -| MRPC | F1/Accuracy | 89.50/84.8 | 88.70/84.02 | 0.56/0.48 | [tfhub.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) | -| STS-B | Pearson/Spearman corr. | 90.00/88.71 | 89.09/88.61 | 0.51/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) | -| QQP | Accuracy/F1 | 90.88/87.64 | 90.75/87.53 | 0.11/0.13 | [tfhub.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) | -| MNLI | Matched acc. | 84.06 | 83.88 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) | -| QNLI | Accuracy | 91.01 | 90.86 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) | -| RTE | Accuracy | 66.80 | 65.27 | 1.07 | [tfhub.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) | -| WNLI | Accuracy | 39.44 | 32.96 | 5.85 | [tfhub.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) | +| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) | +| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) | +| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) | +| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) | +| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) | +| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) | +| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) | +| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) | +| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) | Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index f3453926fe..bf5bb0acac 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -123,7 +123,7 @@ def parse_args(): "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.") + parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.") args = parser.parse_args() # Sanity checks @@ -148,6 +148,7 @@ def create_train_state( learning_rate_fn: Callable[[int], float], is_regression: bool, num_labels: int, + weight_decay: float, ) -> train_state.TrainState: """Create initial training state.""" @@ -166,8 +167,8 @@ def create_train_state( loss_fn: Callable = struct.field(pytree_node=False) # Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers. - def adamw(weight_decay): - return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay) + def adamw(decay): + return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay) def traverse(fn): def mask(data): @@ -183,7 +184,7 @@ def create_train_state( tx = optax.chain( optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))), - optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))), + optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))), ) if is_regression: @@ -414,7 +415,9 @@ def main(): len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate ) - state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels) + state = create_train_state( + model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay + ) # define step functions def train_step( @@ -426,10 +429,10 @@ def main(): def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) - return loss, logits + return loss - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, logits), grad = grad_fn(state.params) + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") @@ -460,10 +463,11 @@ def main(): train_start = time.time() train_metrics = [] - rng, input_rng, dropout_rng = jax.random.split(rng, 3) + rng, input_rng = jax.random.split(rng) # train for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size): + rng, dropout_rng = jax.random.split(rng) dropout_rngs = shard_prng_key(dropout_rng) state, metrics = p_train_step(state, batch, dropout_rngs) train_metrics.append(metrics) @@ -471,7 +475,6 @@ def main(): logger.info(f" Done! Training metrics: {unreplicate(metrics)}") logger.info(" Evaluating...") - rng, input_rng = jax.random.split(rng) # evaluate for batch in glue_eval_data_collator(eval_dataset, eval_batch_size): @@ -484,20 +487,14 @@ def main(): # make sure leftover batch is evaluated on one device if num_leftover_samples > 0 and jax.process_index() == 0: - # put weights on single device - state = unreplicate(state) - # take leftover samples batch = eval_dataset[-num_leftover_samples:] batch = {k: jnp.array(v) for k, v in batch.items()} labels = batch.pop("labels") - predictions = eval_step(state, batch) + predictions = eval_step(unreplicate(state), batch) metric.add_batch(predictions=predictions, references=labels) - # make sure weights are replicated on each device - state = replicate(state) - eval_metric = metric.compute() logger.info(f" Done! Eval metrics: {eval_metric}")