[Flax] Adapt flax examples to include push_to_hub (#12391)

* fix_torch_device_generate_test

* remove @

* finish

* correct summary writer

* correct push to hub

* fix indent

* finish

* finish

* finish

* finish

* finish

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-06-28 19:23:35 +01:00
committed by GitHub
parent a7d0b288fa
commit 2d70c91206
6 changed files with 153 additions and 48 deletions

View File

@@ -269,7 +269,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
return batch_idx
def write_metric(train_metrics, eval_metrics, train_time, step):
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
@@ -472,7 +472,7 @@ if __name__ == "__main__":
# Enable tensorboard only on the master node
if has_tensorboard and jax.process_index() == 0:
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
# Data collator
# This one will take care of randomly masking the tokens.
@@ -642,9 +642,14 @@ if __name__ == "__main__":
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_metric(train_metrics, eval_metrics, train_time, cur_step)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save last checkpoint
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)