From 4605b2b8ec5512a5ea125773bcaa4b0014b32d50 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jul 2021 18:35:22 +0100 Subject: [PATCH] [Flax] Fix another bug in logging steps (#12516) * fix_torch_device_generate_test * remove @ * up --- examples/flax/language-modeling/run_mlm_flax.py | 2 +- examples/flax/language-modeling/run_t5_mlm_flax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 33808e73b8..3bb74d1a06 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -606,7 +606,7 @@ if __name__ == "__main__": state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) train_metrics.append(train_metric) - cur_step = epoch * num_train_samples + step + cur_step = epoch * (num_train_samples // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index d1c262cdf4..dc87f0093a 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -722,7 +722,7 @@ if __name__ == "__main__": state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) train_metrics.append(train_metric) - cur_step = epoch * num_train_samples + step + cur_step = epoch * (num_train_samples // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics