[examples/flax] use Repository API for push_to_hub (#13672)
* use Repository for push_to_hub * update readme * update other flax scripts * update readme * update qa example * fix push_to_hub call * fix typo * fix more typos * update readme * use abosolute path to get repo name * fix glue script
This commit is contained in:
@@ -11,43 +11,12 @@ way which enables simple and efficient model parallelism.
|
||||
|
||||
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
|
||||
|
||||
Let's start by creating a model repository to save the trained model and logs.
|
||||
Here we call the model `"bart-base-xsum"`, but you can change the model name as you like.
|
||||
|
||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||
you are logged in) or via the command line:
|
||||
|
||||
```
|
||||
huggingface-cli repo create bart-base-xsum
|
||||
```
|
||||
Next we clone the model repository to add the tokenizer and model files.
|
||||
```
|
||||
git clone https://huggingface.co/<your-username>/bart-base-xsum
|
||||
```
|
||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||
track them. You can run the following command inside your model repo to do so.
|
||||
|
||||
```
|
||||
cd bart-base-xsum
|
||||
git lfs track "*tfevents*"
|
||||
```
|
||||
|
||||
Great, we have set up our model repository. During training, we will automatically
|
||||
push the training logs and model weights to the repo.
|
||||
|
||||
Next, let's add a symbolic link to the `run_summarization_flax.py`.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="./bart-base-xsum"
|
||||
ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py
|
||||
```
|
||||
|
||||
### Train the model
|
||||
Next we can run the example script to train the model:
|
||||
|
||||
```bash
|
||||
python run_summarization_flax.py \
|
||||
--output_dir ${MODEL_DIR} \
|
||||
--output_dir ./bart-base-xsum \
|
||||
--model_name_or_path facebook/bart-base \
|
||||
--tokenizer_name facebook/bart-base \
|
||||
--dataset_name="xsum" \
|
||||
|
||||
@@ -42,6 +42,7 @@ from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
@@ -52,7 +53,7 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
)
|
||||
from transformers.file_utils import is_offline_mode
|
||||
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -333,6 +334,16 @@ def main():
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Handle the repository creation
|
||||
if training_args.push_to_hub:
|
||||
if training_args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(
|
||||
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||
)
|
||||
else:
|
||||
repo_name = training_args.hub_model_id
|
||||
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
@@ -800,12 +811,10 @@ def main():
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(
|
||||
training_args.output_dir,
|
||||
params=params,
|
||||
push_to_hub=training_args.push_to_hub,
|
||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
||||
)
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user