correct example script (#11726)
This commit is contained in:
committed by
GitHub
parent
bd3b599c12
commit
113eaa7575
@@ -119,12 +119,6 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--gradient_accumulation_steps",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||||
)
|
)
|
||||||
@@ -457,13 +451,13 @@ def main():
|
|||||||
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
||||||
train_time = 0
|
train_time = 0
|
||||||
|
|
||||||
|
# make sure weights are replicated on each device
|
||||||
|
state = replicate(state)
|
||||||
|
|
||||||
for epoch in range(1, num_epochs + 1):
|
for epoch in range(1, num_epochs + 1):
|
||||||
logger.info(f"Epoch {epoch}")
|
logger.info(f"Epoch {epoch}")
|
||||||
logger.info(" Training...")
|
logger.info(" Training...")
|
||||||
|
|
||||||
# make sure weights are replicated on each device
|
|
||||||
state = replicate(state)
|
|
||||||
|
|
||||||
train_start = time.time()
|
train_start = time.time()
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
|
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
|
||||||
@@ -501,6 +495,9 @@ def main():
|
|||||||
predictions = eval_step(state, batch)
|
predictions = eval_step(state, batch)
|
||||||
metric.add_batch(predictions=predictions, references=labels)
|
metric.add_batch(predictions=predictions, references=labels)
|
||||||
|
|
||||||
|
# make sure weights are replicated on each device
|
||||||
|
state = replicate(state)
|
||||||
|
|
||||||
eval_metric = metric.compute()
|
eval_metric = metric.compute()
|
||||||
logger.info(f" Done! Eval metrics: {eval_metric}")
|
logger.info(f" Done! Eval metrics: {eval_metric}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user