[examples/flax] use Repository API for push_to_hub (#13672)

* use Repository for push_to_hub

* update readme

* update other flax scripts

* update readme

* update qa example

* fix push_to_hub call

* fix typo

* fix more typos

* update readme

* use absolute path to get repo name

* fix glue script
Suraj Patil
2021-09-30 16:38:07 +05:30
committed by GitHub
parent b90096fe14
commit 7db2a79b38
15 changed files with 183 additions and 292 deletions

@@ -25,37 +25,6 @@ way which enables simple and efficient model parallelism.
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create vit-base-patch16-imagenette
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd vit-base-patch16-imagenette
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_image_classification_flax.py`.
```bash
export MODEL_DIR="./vit-base-patch16-imagenette"
ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py
```
## Prepare the dataset
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
@@ -86,7 +55,7 @@ Next we can run the example script to fine-tune the model:
```bash
python run_image_classification.py \
--output_dir ${MODEL_DIR} \
--output_dir ./vit-base-patch16-imagenette \
--model_name_or_path google/vit-base-patch16-224-in21k \
--train_dir="imagenette2/train" \
--validation_dir="imagenette2/val" \

@@ -42,6 +42,7 @@ from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@@ -52,6 +53,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
logger = logging.getLogger(__name__)
@@ -205,6 +207,16 @@ def main():
# set seed for random transforms and torch dataloaders
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing
# Note that here we are using some default pre-processing, for maximum accuracy
@@ -455,12 +467,9 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
model.save_pretrained(training_args.output_dir, params=params)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
if __name__ == "__main__":
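
Note on the repo-name hunk above: the sketch below illustrates why the name is taken from the absolute path of `output_dir`. It is a minimal, offline stand-in, not the script's code. `resolve_repo_name` and its `namespace` argument are hypothetical; the real script calls `get_full_repo_name` with a token, which is assumed here to supply the user or organization prefix.

```python
from pathlib import Path
from typing import Optional


def resolve_repo_name(output_dir: str, hub_model_id: Optional[str], namespace: str) -> str:
    """Hypothetical stand-in for the repo-name selection shown in the hunk above."""
    # An explicit hub_model_id always wins, as in the script.
    if hub_model_id is not None:
        return hub_model_id
    # Path(".").name is the empty string, so a relative output_dir such as "."
    # would give an empty repo name; .absolute() resolves it against the
    # current working directory first.
    return f"{namespace}/{Path(output_dir).absolute().name}"


if __name__ == "__main__":
    print(resolve_repo_name("./vit-base-patch16-imagenette", None, "your-username"))
    print(resolve_repo_name("./out", "my-org/custom-name", "your-username"))
```

With `--output_dir .` the derived name becomes the current directory's name rather than an empty string, which appears to be the case the "use absolute path to get repo name" change addresses.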