From 113eaa757589d31ee355c6c5d1fbd29e7e8ba788 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 14 May 2021 12:02:57 +0100 Subject: [PATCH] correct example script (#11726) --- .../flax/text-classification/run_flax_glue.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index f405dd9fc7..f3453926fe 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -119,12 +119,6 @@ def parse_args(): default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) parser.add_argument( "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." ) @@ -457,13 +451,13 @@ def main(): logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 + # make sure weights are replicated on each device + state = replicate(state) + for epoch in range(1, num_epochs + 1): logger.info(f"Epoch {epoch}") logger.info(" Training...") - # make sure weights are replicated on each device - state = replicate(state) - train_start = time.time() train_metrics = [] rng, input_rng, dropout_rng = jax.random.split(rng, 3) @@ -501,6 +495,9 @@ def main(): predictions = eval_step(state, batch) metric.add_batch(predictions=predictions, references=labels) + # make sure weights are replicated on each device + state = replicate(state) + eval_metric = metric.compute() logger.info(f" Done! Eval metrics: {eval_metric}")