diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 9c72cce216..6b0f3becda 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -769,6 +769,14 @@ def main(): cur_step = epoch * (len(train_dataset) // train_batch_size) write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) + # ======================== Prediction loop ============================== if training_args.do_predict: logger.info("*** Predict ***") @@ -808,14 +816,6 @@ def main(): desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})" logger.info(desc) - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained(training_args.output_dir, params=params) - tokenizer.save_pretrained(training_args.output_dir) - if training_args.push_to_hub: - repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) - if __name__ == "__main__": main()