[examples/flax] use Repository API for push_to_hub (#13672)

* use Repository for push_to_hub

* update readme

* update other flax scripts

* update readme

* update qa example

* fix push_to_hub call

* fix typo

* fix more typos

* update readme

* use abosolute path to get repo name

* fix glue script
This commit is contained in:
Suraj Patil
2021-09-30 16:38:07 +05:30
committed by GitHub
parent b90096fe14
commit 7db2a79b38
15 changed files with 183 additions and 292 deletions

View File

@@ -41,6 +41,7 @@ import optax
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -54,6 +55,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -308,6 +310,16 @@ if __name__ == "__main__":
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@@ -683,9 +695,7 @@ if __name__ == "__main__":
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)