[Flax] Correct logging steps flax (#12515)
* fix_torch_device_generate_test * remove @ * push
This commit is contained in:
committed by
GitHub
parent
bb4ac2b5a8
commit
d0f7508abe
@@ -574,7 +574,7 @@ def main():
|
||||
|
||||
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
|
||||
|
||||
if cur_step % training_args.logging_steps and cur_step > 0:
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
train_metric = unreplicate(train_metric)
|
||||
train_time += time.time() - train_start
|
||||
|
||||
@@ -608,7 +608,7 @@ if __name__ == "__main__":
|
||||
|
||||
cur_step = epoch * num_train_samples + step
|
||||
|
||||
if cur_step % training_args.logging_steps and cur_step > 0:
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
train_time += time.time() - train_start
|
||||
|
||||
@@ -724,7 +724,7 @@ if __name__ == "__main__":
|
||||
|
||||
cur_step = epoch * num_train_samples + step
|
||||
|
||||
if cur_step % training_args.logging_steps and cur_step > 0:
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
train_time += time.time() - train_start
|
||||
|
||||
Reference in New Issue
Block a user