From d0f7508abeb9caccc737c060350b22fb5568ef97 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jul 2021 18:21:00 +0100 Subject: [PATCH] [Flax] Correct logging steps flax (#12515) * fix_torch_device_generate_test * remove @ * push --- examples/flax/language-modeling/run_clm_flax.py | 2 +- examples/flax/language-modeling/run_mlm_flax.py | 2 +- examples/flax/language-modeling/run_t5_mlm_flax.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 8ade7d4284..b63612bd93 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -574,7 +574,7 @@ def main(): cur_step = epoch * (len(train_dataset) // train_batch_size) + step - if cur_step % training_args.logging_steps and cur_step > 0: + if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 5d9fda11a3..33808e73b8 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -608,7 +608,7 @@ if __name__ == "__main__": cur_step = epoch * num_train_samples + step - if cur_step % training_args.logging_steps and cur_step > 0: + if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = jax_utils.unreplicate(train_metric) train_time += time.time() - train_start diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 795dc7faeb..d1c262cdf4 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -724,7 +724,7 @@ if __name__ == "__main__": cur_step = epoch * num_train_samples + step - if cur_step % training_args.logging_steps and cur_step > 0: + if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = jax_utils.unreplicate(train_metric) train_time += time.time() - train_start