[Flax] Correct logging steps flax (#12515)

* fix_torch_device_generate_test

* remove @

* push
This commit is contained in:
Patrick von Platen
2021-07-05 18:21:00 +01:00
committed by GitHub
parent bb4ac2b5a8
commit d0f7508abe
3 changed files with 3 additions and 3 deletions

View File

@@ -608,7 +608,7 @@ if __name__ == "__main__":
cur_step = epoch * num_train_samples + step
if cur_step % training_args.logging_steps and cur_step > 0:
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = jax_utils.unreplicate(train_metric)
train_time += time.time() - train_start