[Flax] Adapt flax examples to include push_to_hub (#12391)

* fix_torch_device_generate_test

* remove @

* finish

* correct summary writer

* correct push to hub

* fix indent

* finish

* finish

* finish

* finish

* finish

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-06-28 19:23:35 +01:00
committed by GitHub
parent a7d0b288fa
commit 2d70c91206
6 changed files with 153 additions and 48 deletions

View File

@@ -542,7 +542,7 @@ def main():
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
@@ -787,10 +787,15 @@ def main():
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
logger.info(desc)
# save last checkpoint
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(training_args.output_dir, params=params)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
if __name__ == "__main__":