From deecdd4939f98394703617cd485fea8d9f986f8c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 9 Jul 2021 13:51:28 +0100 Subject: [PATCH] [Flax] Fix cur step flax examples (#12608) * fix_torch_device_generate_test * remove @ * fix save problem --- examples/flax/language-modeling/run_clm_flax.py | 1 - examples/flax/language-modeling/run_mlm_flax.py | 1 - examples/flax/language-modeling/run_t5_mlm_flax.py | 1 - 3 files changed, 3 deletions(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 1367ad2db9..bddd5b9905 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -622,7 +622,6 @@ def main(): # Save metrics if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(train_dataset) // train_batch_size) write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 56cc35d969..4282560dac 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -663,7 +663,6 @@ if __name__ == "__main__": # Save metrics if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index f04f586001..001bea329a 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -771,7 +771,6 @@ if __name__ == "__main__": # Save metrics if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: