correct example script (#11726)

This commit is contained in:
Patrick von Platen
2021-05-14 12:02:57 +01:00
committed by GitHub
parent bd3b599c12
commit 113eaa7575

View File

@@ -119,12 +119,6 @@ def parse_args():
default=None, default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
) )
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument( parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
) )
@@ -457,13 +451,13 @@ def main():
logger.info(f"===== Starting training ({num_epochs} epochs) =====") logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0 train_time = 0
# make sure weights are replicated on each device
state = replicate(state)
for epoch in range(1, num_epochs + 1): for epoch in range(1, num_epochs + 1):
logger.info(f"Epoch {epoch}") logger.info(f"Epoch {epoch}")
logger.info(" Training...") logger.info(" Training...")
# make sure weights are replicated on each device
state = replicate(state)
train_start = time.time() train_start = time.time()
train_metrics = [] train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3) rng, input_rng, dropout_rng = jax.random.split(rng, 3)
@@ -501,6 +495,9 @@ def main():
predictions = eval_step(state, batch) predictions = eval_step(state, batch)
metric.add_batch(predictions=predictions, references=labels) metric.add_batch(predictions=predictions, references=labels)
# make sure weights are replicated on each device
state = replicate(state)
eval_metric = metric.compute() eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}") logger.info(f" Done! Eval metrics: {eval_metric}")