[examples/flax] use Repository API for push_to_hub (#13672)
* use Repository for push_to_hub
* update readme
* update other flax scripts
* update readme
* update qa example
* fix push_to_hub call
* fix typo
* fix more typos
* update readme
* use absolute path to get repo name
* fix glue script
@@ -61,3 +61,14 @@ For a complete overview of models that are supported in JAX/Flax, please have a
|
|||||||
|
|
||||||
Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021.
|
Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021.
|
||||||
Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub.
|
Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub.
|
||||||
|
|
||||||
|
## Upload the trained/fine-tuned model to the Hub
|
||||||
|
|
||||||
|
All the example scripts support automatic upload of your final model to the [Model Hub](https://huggingface.co/models) by adding a `--push_to_hub` argument. The script then creates a repository named after your username and the folder you use as `output_dir`, joined by a slash. For instance, the repository is `"sgugger/test-mrpc"` if your username is `sgugger` and you are working in the folder `~/tmp/test-mrpc`.
|
||||||
|
|
||||||
|
To specify a given repository name, use the `--hub_model_id` argument. You will need to specify the whole repository name (including your username), for instance `--hub_model_id sgugger/finetuned-bert-mrpc`. To upload to an organization you are a member of, just use the name of that organization instead of your username: `--hub_model_id huggingface/finetuned-bert-mrpc`.
|
||||||
|
|
||||||
|
A few notes on this integration:
|
||||||
|
|
||||||
|
- you will need to be logged in to the Hugging Face website locally for it to work. The easiest way to achieve this is to run `huggingface-cli login` and then type your username and password when prompted. You can also pass along your authentication token with the `--hub_token` argument.
|
||||||
|
- the `output_dir` you pick will either need to be a new folder or a local clone of the distant repository you are using.
|
||||||
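As a rough illustration of the default naming rule described above, the sketch below derives the repository name from `output_dir`. The helper `default_repo_name` and its `username` parameter are hypothetical stand-ins; the example scripts instead call `get_full_repo_name` from `transformers.file_utils`, which resolves the username from your token. Resolving to an absolute path first matters because a relative path such as `./` has an empty final component.

```python
from pathlib import Path


def default_repo_name(output_dir: str, username: str) -> str:
    """Hypothetical helper: build "<username>/<folder-name>" from output_dir."""
    # Resolve against the current working directory so that "./" or "."
    # still yields a real folder name instead of an empty string.
    folder_name = Path(output_dir).absolute().name
    return f"{username}/{folder_name}"


# A relative path with no final component has an empty name ...
print(repr(Path("./").name))
# ... while a named folder keeps its name regardless of the leading "./".
print(default_repo_name("./test-mrpc", "sgugger"))
```

If the derived name is not the one you want, `--hub_model_id` bypasses this rule entirely.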
|
|||||||
@@ -33,32 +33,10 @@ in Norwegian on a single TPUv3-8 pod.
|
|||||||
|
|
||||||
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
|
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
|
||||||
|
|
||||||
Let's start by creating a model repository to save the trained model and logs.
|
To set up all relevant files for training, let's create a directory.
|
||||||
Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create norwegian-roberta-base
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
|
|
||||||
```
|
|
||||||
|
|
||||||
To set up all relevant files for training, let's go into the cloned model directory.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd norwegian-roberta-base
|
mkdir ./norwegian-roberta-base
|
||||||
```
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_mlm_flax.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Train tokenizer
|
### Train tokenizer
|
||||||
@@ -92,7 +70,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Save files to disk
|
# Save files to disk
|
||||||
tokenizer.save("./tokenizer.json")
|
tokenizer.save("./norwegian-roberta-base/tokenizer.json")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Create configuration
|
### Create configuration
|
||||||
@@ -105,7 +83,7 @@ in the local model folder:
|
|||||||
from transformers import RobertaConfig
|
from transformers import RobertaConfig
|
||||||
|
|
||||||
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
|
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
|
||||||
config.save_pretrained("./")
|
config.save_pretrained("./norwegian-roberta-base")
|
||||||
```
|
```
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
Great, we have set up our model repository. During training, we will automatically
|
||||||
@@ -116,11 +94,11 @@ push the training logs and model weights to the repo.
|
|||||||
Next we can run the example script to pretrain the model:
|
Next we can run the example script to pretrain the model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./run_mlm_flax.py \
|
python run_mlm_flax.py \
|
||||||
--output_dir="./" \
|
--output_dir="./norwegian-roberta-base" \
|
||||||
--model_type="roberta" \
|
--model_type="roberta" \
|
||||||
--config_name="./" \
|
--config_name="./norwegian-roberta-base" \
|
||||||
--tokenizer_name="./" \
|
--tokenizer_name="./norwegian-roberta-base" \
|
||||||
--dataset_name="oscar" \
|
--dataset_name="oscar" \
|
||||||
--dataset_config_name="unshuffled_deduplicated_no" \
|
--dataset_config_name="unshuffled_deduplicated_no" \
|
||||||
--max_seq_length="128" \
|
--max_seq_length="128" \
|
||||||
@@ -157,32 +135,11 @@ in Norwegian on a single TPUv3-8 pod.
|
|||||||
|
|
||||||
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
|
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
|
||||||
|
|
||||||
Let's start by creating a model repository to save the trained model and logs.
|
|
||||||
Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
To set up all relevant files for training, let's create a directory.
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create norwegian-gpt2
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/norwegian-gpt2
|
|
||||||
```
|
|
||||||
|
|
||||||
To set up all relevant files for training, let's go into the cloned model directory.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd norwegian-gpt2
|
mkdir ./norwegian-gpt2
|
||||||
```
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the training script `run_clm_flax.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Train tokenizer
|
### Train tokenizer
|
||||||
@@ -216,7 +173,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Save files to disk
|
# Save files to disk
|
||||||
tokenizer.save("./tokenizer.json")
|
tokenizer.save("./norwegian-gpt2/tokenizer.json")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Create configuration
|
### Create configuration
|
||||||
@@ -229,7 +186,7 @@ in the local model folder:
|
|||||||
from transformers import GPT2Config
|
from transformers import GPT2Config
|
||||||
|
|
||||||
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
|
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
|
||||||
config.save_pretrained("./")
|
config.save_pretrained("./norwegian-gpt2")
|
||||||
```
|
```
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will now automatically
|
Great, we have set up our model repository. During training, we will now automatically
|
||||||
@@ -240,11 +197,11 @@ push the training logs and model weights to the repo.
|
|||||||
Finally, we can run the example script to pretrain the model:
|
Finally, we can run the example script to pretrain the model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./run_clm_flax.py \
|
python run_clm_flax.py \
|
||||||
--output_dir="./" \
|
--output_dir="./norwegian-gpt2" \
|
||||||
--model_type="gpt2" \
|
--model_type="gpt2" \
|
||||||
--config_name="./" \
|
--config_name="./norwegian-gpt2" \
|
||||||
--tokenizer_name="./" \
|
--tokenizer_name="./norwegian-gpt2" \
|
||||||
--dataset_name="oscar" \
|
--dataset_name="oscar" \
|
||||||
--dataset_config_name="unshuffled_deduplicated_no" \
|
--dataset_config_name="unshuffled_deduplicated_no" \
|
||||||
--do_train --do_eval \
|
--do_train --do_eval \
|
||||||
@@ -282,30 +239,10 @@ The example script uses the 🤗 Datasets library. You can easily customize them
|
|||||||
Let's start by creating a model repository to save the trained model and logs.
|
Let's start by creating a model repository to save the trained model and logs.
|
||||||
Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
|
Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
To set up all relevant files for training, let's create a directory.
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create norwegian-t5-base
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/norwegian-t5-base
|
|
||||||
```
|
|
||||||
|
|
||||||
To set up all relevant files for training, let's go into the cloned model directory.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd norwegian-t5-base
|
cd ./norwegian-t5-base
|
||||||
```
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_t5_mlm_flax.py` and `t5_tokenizer_model` scripts.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Train tokenizer
|
### Train tokenizer
|
||||||
@@ -351,7 +288,7 @@ tokenizer.train_from_iterator(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save files to disk
|
# Save files to disk
|
||||||
tokenizer.save("./tokenizer.json")
|
tokenizer.save("./norwegian-t5-base/tokenizer.json")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Create configuration
|
### Create configuration
|
||||||
@@ -364,7 +301,7 @@ in the local model folder:
|
|||||||
from transformers import T5Config
|
from transformers import T5Config
|
||||||
|
|
||||||
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
|
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
|
||||||
config.save_pretrained("./")
|
config.save_pretrained("./norwegian-t5-base")
|
||||||
```
|
```
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
Great, we have set up our model repository. During training, we will automatically
|
||||||
@@ -375,11 +312,11 @@ push the training logs and model weights to the repo.
|
|||||||
Next we can run the example script to pretrain the model:
|
Next we can run the example script to pretrain the model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./run_t5_mlm_flax.py \
|
python run_t5_mlm_flax.py \
|
||||||
--output_dir="./" \
|
--output_dir="./norwegian-t5-base" \
|
||||||
--model_type="t5" \
|
--model_type="t5" \
|
||||||
--config_name="./" \
|
--config_name="./norwegian-t5-base" \
|
||||||
--tokenizer_name="./" \
|
--tokenizer_name="./norwegian-t5-base" \
|
||||||
--dataset_name="oscar" \
|
--dataset_name="oscar" \
|
||||||
--dataset_config_name="unshuffled_deduplicated_no" \
|
--dataset_config_name="unshuffled_deduplicated_no" \
|
||||||
--max_seq_length="512" \
|
--max_seq_length="512" \
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from flax import jax_utils, traverse_util
|
|||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
@@ -54,6 +55,7 @@ from transformers import (
|
|||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
from transformers.testing_utils import CaptureLogger
|
from transformers.testing_utils import CaptureLogger
|
||||||
|
|
||||||
|
|
||||||
@@ -275,6 +277,16 @@ def main():
|
|||||||
# Set seed before initializing model.
|
# Set seed before initializing model.
|
||||||
set_seed(training_args.seed)
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -654,12 +666,10 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(unreplicate(state.params))
|
params = jax.device_get(unreplicate(state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ import optax
|
|||||||
from flax import jax_utils, traverse_util
|
from flax import jax_utils, traverse_util
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
@@ -54,6 +55,7 @@ from transformers import (
|
|||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||||
@@ -308,6 +310,16 @@ if __name__ == "__main__":
|
|||||||
# Set seed before initializing model.
|
# Set seed before initializing model.
|
||||||
set_seed(training_args.seed)
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -683,9 +695,7 @@ if __name__ == "__main__":
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ import optax
|
|||||||
from flax import jax_utils, traverse_util
|
from flax import jax_utils, traverse_util
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
@@ -52,6 +53,7 @@ from transformers import (
|
|||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
@@ -438,6 +440,16 @@ if __name__ == "__main__":
|
|||||||
# Set seed before initializing model.
|
# Set seed before initializing model.
|
||||||
set_seed(training_args.seed)
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -791,9 +803,7 @@ if __name__ == "__main__":
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -26,31 +26,6 @@ of the script.
|
|||||||
|
|
||||||
The following example fine-tunes BERT on SQuAD:
|
The following example fine-tunes BERT on SQuAD:
|
||||||
|
|
||||||
To begin with it is recommended to create a model repository to save the trained model and logs.
|
|
||||||
Here we call the model `"bert-qa-squad-test"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create bert-qa-squad-test
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/bert-qa-squad-test
|
|
||||||
```
|
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
|
||||||
push the training logs and model weights to the repo.
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_qa.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export MODEL_DIR="./bert-qa-squad-test"
|
|
||||||
ln -s ~/transformers/examples/flax/question-answering/run_qa.py run_qa.py
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_qa.py \
|
python run_qa.py \
|
||||||
@@ -63,7 +38,7 @@ python run_qa.py \
|
|||||||
--learning_rate 3e-5 \
|
--learning_rate 3e-5 \
|
||||||
--num_train_epochs 2 \
|
--num_train_epochs 2 \
|
||||||
--per_device_train_batch_size 12 \
|
--per_device_train_batch_size 12 \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ./bert-qa-squad \
|
||||||
--eval_steps 1000 \
|
--eval_steps 1000 \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
@@ -101,8 +76,9 @@ python run_qa.py \
|
|||||||
--num_train_epochs 2 \
|
--num_train_epochs 2 \
|
||||||
--max_seq_length 384 \
|
--max_seq_length 384 \
|
||||||
--doc_stride 128 \
|
--doc_stride 128 \
|
||||||
--output_dir /tmp/wwm_uncased_finetuned_squad/ \
|
--output_dir ./wwm_uncased_finetuned_squad/ \
|
||||||
--eval_steps 1000
|
--eval_steps 1000 \
|
||||||
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
Training with the previously defined hyper-parameters yields the following results:
|
Training with the previously defined hyper-parameters yields the following results:
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -41,6 +42,7 @@ from flax.jax_utils import replicate, unreplicate
|
|||||||
from flax.metrics import tensorboard
|
from flax.metrics import tensorboard
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -50,6 +52,7 @@ from transformers import (
|
|||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from utils_qa import postprocess_qa_predictions
|
from utils_qa import postprocess_qa_predictions
|
||||||
|
|
||||||
@@ -359,6 +362,16 @@ def main():
|
|||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# region Load Data
|
# region Load Data
|
||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
@@ -891,12 +904,10 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(unreplicate(state.params))
|
params = jax.device_get(unreplicate(state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
||||||
)
|
|
||||||
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|||||||
@@ -11,43 +11,12 @@ way which enables simple and efficient model parallelism.
|
|||||||
|
|
||||||
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
|
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
|
||||||
|
|
||||||
Let's start by creating a model repository to save the trained model and logs.
|
|
||||||
Here we call the model `"bart-base-xsum"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create bart-base-xsum
|
|
||||||
```
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/bart-base-xsum
|
|
||||||
```
|
|
||||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
|
||||||
track them. You can run the following command inside your model repo to do so.
|
|
||||||
|
|
||||||
```
|
|
||||||
cd bart-base-xsum
|
|
||||||
git lfs track "*tfevents*"
|
|
||||||
```
|
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
|
||||||
push the training logs and model weights to the repo.
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_summarization_flax.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export MODEL_DIR="./bart-base-xsum"
|
|
||||||
ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Train the model
|
### Train the model
|
||||||
Next we can run the example script to train the model:
|
Next we can run the example script to train the model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_summarization_flax.py \
|
python run_summarization_flax.py \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ./bart-base-xsum \
|
||||||
--model_name_or_path facebook/bart-base \
|
--model_name_or_path facebook/bart-base \
|
||||||
--tokenizer_name facebook/bart-base \
|
--tokenizer_name facebook/bart-base \
|
||||||
--dataset_name="xsum" \
|
--dataset_name="xsum" \
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from flax import jax_utils, traverse_util
|
|||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
@@ -52,7 +53,7 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import is_offline_mode
|
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -333,6 +334,16 @@ def main():
|
|||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -800,12 +811,10 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -21,47 +21,15 @@ limitations under the License.
|
|||||||
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
|
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
|
||||||
|
|
||||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
|
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a
|
||||||
|
dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case,
|
||||||
To begin with it is recommended to create a model repository to save the trained model and logs.
|
refer to the comments inside for help).
|
||||||
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create bert-glue-mrpc-test
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
|
|
||||||
```
|
|
||||||
|
|
||||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
|
||||||
track them. You can run the following command inside your model repo to do so.
|
|
||||||
|
|
||||||
```
|
|
||||||
cd bert-glue-mrpc-test
|
|
||||||
git lfs track "*tfevents*"
|
|
||||||
```
|
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
|
||||||
push the training logs and model weights to the repo.
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_flax_glue.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export TASK_NAME=mrpc
|
|
||||||
export MODEL_DIR="./bert-glue-mrpc-test"
|
|
||||||
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
export TASK_NAME=mrpc
|
||||||
|
|
||||||
python run_flax_glue.py \
|
python run_flax_glue.py \
|
||||||
--model_name_or_path bert-base-cased \
|
--model_name_or_path bert-base-cased \
|
||||||
--task_name ${TASK_NAME} \
|
--task_name ${TASK_NAME} \
|
||||||
@@ -69,7 +37,7 @@ python run_flax_glue.py \
|
|||||||
--learning_rate 2e-5 \
|
--learning_rate 2e-5 \
|
||||||
--num_train_epochs 3 \
|
--num_train_epochs 3 \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ./$TASK_NAME/ \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Tuple
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -34,7 +35,9 @@ from flax.jax_utils import replicate, unreplicate
|
|||||||
from flax.metrics import tensorboard
|
from flax.metrics import tensorboard
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
|
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -128,6 +131,10 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
|
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||||
|
)
|
||||||
|
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
@@ -141,6 +148,9 @@ def parse_args():
|
|||||||
extension = args.validation_file.split(".")[-1]
|
extension = args.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -267,6 +277,14 @@ def main():
|
|||||||
datasets.utils.logging.set_verbosity_error()
|
datasets.utils.logging.set_verbosity_error()
|
||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if args.push_to_hub:
|
||||||
|
if args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
|
||||||
|
else:
|
||||||
|
repo_name = args.hub_model_id
|
||||||
|
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||||
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
|
|
||||||
@@ -499,12 +517,10 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(args.output_dir, params=params)
|
||||||
args.output_dir,
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
params=params,
|
if args.push_to_hub:
|
||||||
push_to_hub=args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of epoch {epoch}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -22,31 +22,6 @@ It will either run on a datasets hosted on our hub or with your own text files f
|
|||||||
|
|
||||||
The following example fine-tunes BERT on CoNLL-2003:
|
The following example fine-tunes BERT on CoNLL-2003:
|
||||||
|
|
||||||
To begin with it is recommended to create a model repository to save the trained model and logs.
|
|
||||||
Here we call the model `"bert-ner-conll2003-test"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create bert-ner-conll2003-test
|
|
||||||
```
|
|
||||||
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/bert-ner-conll2003-test
|
|
||||||
```
|
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
|
||||||
push the training logs and model weights to the repo.
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_flax_ner.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export MODEL_DIR="./bert-ner-conll2003-test"
|
|
||||||
ln -s ~/transformers/examples/flax/token-classification/run_flax_ner.py run_flax_ner.py
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_flax_ner.py \
|
python run_flax_ner.py \
|
||||||
@@ -56,7 +31,7 @@ python run_flax_ner.py \
|
|||||||
--learning_rate 2e-5 \
|
--learning_rate 2e-5 \
|
||||||
--num_train_epochs 3 \
|
--num_train_epochs 3 \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ./bert-ner-conll2003 \
|
||||||
--eval_steps 300 \
|
--eval_steps 300 \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -37,6 +38,7 @@ from flax.jax_utils import replicate, unreplicate
|
|||||||
from flax.metrics import tensorboard
|
from flax.metrics import tensorboard
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -44,6 +46,7 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
@@ -304,6 +307,16 @@ def main():
|
|||||||
datasets.utils.logging.set_verbosity_error()
|
datasets.utils.logging.set_verbosity_error()
|
||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||||
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
|
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -656,12 +669,10 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(unreplicate(state.params))
|
params = jax.device_get(unreplicate(state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
params=params,
|
if training_args.push_to_hub:
|
||||||
push_to_hub=training_args.push_to_hub,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
||||||
)
|
|
||||||
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,37 +25,6 @@ way which enables simple and efficient model parallelism.
|
|||||||
|
|
||||||
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
|
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
|
||||||
|
|
||||||
Let's start by creating a model repository to save the trained model and logs.
|
|
||||||
Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.
|
|
||||||
|
|
||||||
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
|
||||||
you are logged in) or via the command line:
|
|
||||||
|
|
||||||
```
|
|
||||||
huggingface-cli repo create vit-base-patch16-imagenette
|
|
||||||
```
|
|
||||||
Next we clone the model repository to add the tokenizer and model files.
|
|
||||||
```
|
|
||||||
git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
|
|
||||||
```
|
|
||||||
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
|
||||||
track them. You can run the following command inside your model repo to do so.
|
|
||||||
|
|
||||||
```
|
|
||||||
cd vit-base-patch16-imagenette
|
|
||||||
git lfs track "*tfevents*"
|
|
||||||
```
|
|
||||||
|
|
||||||
Great, we have set up our model repository. During training, we will automatically
|
|
||||||
push the training logs and model weights to the repo.
|
|
||||||
|
|
||||||
Next, let's add a symbolic link to the `run_image_classification_flax.py`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export MODEL_DIR="./vit-base-patch16-imagenette
|
|
||||||
ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Prepare the dataset
|
## Prepare the dataset
|
||||||
|
|
||||||
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
|
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
|
||||||
@@ -86,7 +55,7 @@ Next we can run the example script to fine-tune the model:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_image_classification.py \
|
python run_image_classification.py \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ./vit-base-patch16-imagenette \
|
||||||
--model_name_or_path google/vit-base-patch16-224-in21k \
|
--model_name_or_path google/vit-base-patch16-224-in21k \
|
||||||
--train_dir="imagenette2/train" \
|
--train_dir="imagenette2/train" \
|
||||||
--validation_dir="imagenette2/val" \
|
--validation_dir="imagenette2/val" \
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from flax import jax_utils
|
|||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from huggingface_hub import Repository
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
@@ -52,6 +53,7 @@ from transformers import (
|
|||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import get_full_repo_name
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -205,6 +207,16 @@ def main():
|
|||||||
# set seed for random transforms and torch dataloaders
|
# set seed for random transforms and torch dataloaders
|
||||||
set_seed(training_args.seed)
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
if training_args.hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_name = training_args.hub_model_id
|
||||||
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Initialize datasets and pre-processing transforms
|
# Initialize datasets and pre-processing transforms
|
||||||
# We use torchvision here for faster pre-processing
|
# We use torchvision here for faster pre-processing
|
||||||
# Note that here we are using some default pre-processing, for maximum accuray
|
# Note that here we are using some default pre-processing, for maximum accuray
|
||||||
@@ -455,12 +467,9 @@ def main():
|
|||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
training_args.output_dir,
|
if training_args.push_to_hub:
|
||||||
params=params,
|
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||||
push_to_hub=training_args.push_to_hub,
|
|
||||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
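The "Handle the repository creation" block added to each script picks the repository name the same way: an explicit `hub_model_id` is used as-is, otherwise the name of the `output_dir` folder is prefixed with the user's namespace. The sketch below illustrates that rule only. `default_repo_name` is a hypothetical helper, and the `username` argument is an assumption standing in for the lookup that `get_full_repo_name` performs with the token, so no call to the Hub is made here.

```python
from pathlib import Path
from typing import Optional


def default_repo_name(output_dir: str, username: str, hub_model_id: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the repo-name rule in the diff, without contacting the Hub."""
    if hub_model_id is not None:
        # An explicit --hub_model_id is used verbatim (it already contains the namespace).
        return hub_model_id
    # .absolute() anchors relative paths at the current working directory,
    # so a value such as "." still yields a folder name instead of an empty string.
    folder = Path(output_dir).absolute().name
    return f"{username}/{folder}"


print(default_repo_name("./bart-base-xsum", "sgugger"))
print(default_repo_name("./mrpc/", "sgugger"))
print(default_repo_name("out", "sgugger", hub_model_id="huggingface/finetuned-bert-mrpc"))
```

Taking the absolute path first matters for values like `.`: `Path(".").name` is an empty string, whereas `Path(".").absolute().name` is the name of the current folder.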