[examples/flax] use Repository API for push_to_hub (#13672)
* use Repository for push_to_hub * update readme * update other flax scripts * update readme * update qa example * fix push_to_hub call * fix typo * fix more typos * update readme * use abosolute path to get repo name * fix glue script
This commit is contained in:
@@ -21,47 +21,15 @@ limitations under the License.
|
||||
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
|
||||
|
||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
|
||||
|
||||
To begin with it is recommended to create a model repository to save the trained model and logs.
|
||||
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
|
||||
|
||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||
you are logged in) or via the command line:
|
||||
|
||||
```
|
||||
huggingface-cli repo create bert-glue-mrpc-test
|
||||
```
|
||||
|
||||
Next we clone the model repository to add the tokenizer and model files.
|
||||
|
||||
```
|
||||
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
|
||||
```
|
||||
|
||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||
track them. You can run the following command inside your model repo to do so.
|
||||
|
||||
```
|
||||
cd bert-glue-mrpc-test
|
||||
git lfs track "*tfevents*"
|
||||
```
|
||||
|
||||
Great, we have set up our model repository. During training, we will automatically
|
||||
push the training logs and model weights to the repo.
|
||||
|
||||
Next, let's add a symbolic link to the `run_flax_glue.py`.
|
||||
|
||||
```bash
|
||||
export TASK_NAME=mrpc
|
||||
export MODEL_DIR="./bert-glue-mrpc-test"
|
||||
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
|
||||
```
|
||||
|
||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a
|
||||
dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case,
|
||||
refer to the comments inside for help).
|
||||
|
||||
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
||||
|
||||
```bash
|
||||
export TASK_NAME=mrpc
|
||||
|
||||
python run_flax_glue.py \
|
||||
--model_name_or_path bert-base-cased \
|
||||
--task_name ${TASK_NAME} \
|
||||
@@ -69,7 +37,7 @@ python run_flax_glue.py \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--output_dir ${MODEL_DIR} \
|
||||
--output_dir ./$TASK_NAME/ \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import datasets
|
||||
@@ -34,7 +35,9 @@ from flax.jax_utils import replicate, unreplicate
|
||||
from flax.metrics import tensorboard
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository
|
||||
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
|
||||
from transformers.file_utils import get_full_repo_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -128,6 +131,10 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -141,6 +148,9 @@ def parse_args():
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
@@ -267,6 +277,14 @@ def main():
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# Handle the repository creation
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
@@ -499,12 +517,10 @@ def main():
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(
|
||||
args.output_dir,
|
||||
params=params,
|
||||
push_to_hub=args.push_to_hub,
|
||||
commit_message=f"Saving weights and logs of epoch {epoch}",
|
||||
)
|
||||
model.save_pretrained(args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user