correct example script (#11726)
This commit is contained in:
committed by
GitHub
parent
bd3b599c12
commit
113eaa7575
@@ -119,12 +119,6 @@ def parse_args():
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
@@ -457,13 +451,13 @@ def main():
|
||||
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
||||
train_time = 0
|
||||
|
||||
# make sure weights are replicated on each device
|
||||
state = replicate(state)
|
||||
|
||||
for epoch in range(1, num_epochs + 1):
|
||||
logger.info(f"Epoch {epoch}")
|
||||
logger.info(" Training...")
|
||||
|
||||
# make sure weights are replicated on each device
|
||||
state = replicate(state)
|
||||
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
|
||||
@@ -501,6 +495,9 @@ def main():
|
||||
predictions = eval_step(state, batch)
|
||||
metric.add_batch(predictions=predictions, references=labels)
|
||||
|
||||
# make sure weights are replicated on each device
|
||||
state = replicate(state)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
logger.info(f" Done! Eval metrics: {eval_metric}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user