[Flax] token-classification model steps enumerate start from 1 (#14547)
* step start from 1 * Updated cur_step calcualtion
This commit is contained in:
@@ -598,7 +598,7 @@ def main():
|
|||||||
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
|
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
|
||||||
train_metrics.append(train_metric)
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
cur_step = epoch * step_per_epoch + step
|
cur_step = (epoch * step_per_epoch) + (step + 1)
|
||||||
|
|
||||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||||
# Save metrics
|
# Save metrics
|
||||||
|
|||||||
Reference in New Issue
Block a user