From a6ea244f99c4f59b946cc7a25eee9461c42fb990 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 5 Oct 2021 13:00:13 +0200 Subject: [PATCH] Fix: save checkpoint after each epoch and push checkpoint to the hub (#13872) Co-authored-by: ydshieh --- .../flax/summarization/run_summarization_flax.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 9c72cce216..6b0f3becda 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -769,6 +769,14 @@ def main(): cur_step = epoch * (len(train_dataset) // train_batch_size) write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) + # ======================== Prediction loop ============================== if training_args.do_predict: logger.info("*** Predict ***") @@ -808,14 +816,6 @@ def main(): desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})" logger.info(desc) - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained(training_args.output_dir, params=params) - tokenizer.save_pretrained(training_args.output_dir) - if training_args.push_to_hub: - repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) - if __name__ == "__main__": main()