add readme for flax clm (#12111)
* add readme for flax clm
* use section link for tokenizer
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update metrics

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -125,6 +125,68 @@ Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experi
For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/masked_language_modeling_flax.ipynb) Google Colab.

## Causal language modeling

In the following, we demonstrate how to train an auto-regressive causal transformer model
in JAX/Flax.
More specifically, we pretrain a randomly initialized 124M-parameter [**`gpt2`**](https://huggingface.co/gpt2) model
in Norwegian on a single TPUv3-8.

The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.

Let's start by creating a folder to save the trained model and a symbolic link to the `run_clm_flax.py` script.

```bash
export MODEL_DIR="./norwegian-gpt2"
mkdir -p ${MODEL_DIR}
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
```

Next, we'll follow the same steps as above in [Train tokenizer](#train-tokenizer) to train the tokenizer.

### Create configuration

Next, we create the model's configuration file. This is as simple
as loading and storing [**`gpt2`**](https://huggingface.co/gpt2)
in the local model folder:

```python
from transformers import GPT2Config

model_dir = "./norwegian-gpt2"  # ${MODEL_DIR}

# Load the gpt2 configuration with all dropout probabilities set to zero,
# then store it in the local model folder.
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0)
config.save_pretrained(model_dir)
```
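
As an optional sanity check, the stored configuration can be read back with the standard library. This is a minimal sketch, assuming `save_pretrained` wrote a `config.json` into the model folder and that the dropout settings appear there under the same names as the keyword arguments above; the helper name `read_dropout_settings` is hypothetical.

```python
import json
from pathlib import Path

# Keys that were set to 0.0 when the configuration was created.
DROPOUT_KEYS = ("resid_pdrop", "embd_pdrop", "attn_pdrop")


def read_dropout_settings(model_dir):
    """Return the dropout values stored in <model_dir>/config.json."""
    config_path = Path(model_dir) / "config.json"
    with config_path.open(encoding="utf-8") as f:
        config = json.load(f)
    # .get() yields None for any key that is missing from the file.
    return {key: config.get(key) for key in DROPOUT_KEYS}


if __name__ == "__main__":
    model_dir = Path("./norwegian-gpt2")  # same folder as ${MODEL_DIR}
    # Only print when the configuration has actually been saved.
    if (model_dir / "config.json").exists():
        print(read_dropout_settings(model_dir))
```

If any of the three values is not `0.0`, the configuration step was probably run with different arguments than the ones shown here.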

### Train model

Next we can run the example script to pretrain the model:

```bash
./run_clm_flax.py \
    --output_dir="./runs" \
    --model_type="gpt2" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --do_train --do_eval \
    --block_size="512" \
    --per_device_train_batch_size="64" \
    --per_device_eval_batch_size="64" \
    --learning_rate="5e-3" --warmup_steps="1000" \
    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
    --overwrite_output_dir \
    --num_train_epochs="20"
```
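
To get a feel for what these flags imply, the arithmetic below works out the global batch size and the number of tokens per optimizer step. It is a rough sketch under two assumptions that the text above does not state: that a TPUv3-8 exposes 8 devices to JAX, and that the script trains data-parallel with one batch of `per_device_train_batch_size` sequences on every device. The function names are hypothetical.

```python
# Values taken from the command above.
PER_DEVICE_TRAIN_BATCH_SIZE = 64
BLOCK_SIZE = 512

# Assumption: a TPUv3-8 exposes 8 devices (cores) to JAX.
NUM_DEVICES = 8


def global_batch_size(per_device_batch_size, num_devices):
    """Sequences processed per optimizer step across all devices."""
    return per_device_batch_size * num_devices


def tokens_per_step(per_device_batch_size, num_devices, block_size):
    """Tokens per optimizer step when every sequence is block_size tokens long."""
    return global_batch_size(per_device_batch_size, num_devices) * block_size


print(global_batch_size(PER_DEVICE_TRAIN_BATCH_SIZE, NUM_DEVICES))  # 512
print(tokens_per_step(PER_DEVICE_TRAIN_BATCH_SIZE, NUM_DEVICES, BLOCK_SIZE))  # 262144
```

On hardware with fewer devices, the global batch size shrinks accordingly, so the learning rate and warmup may need to be retuned.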

Training should converge at a loss of 3.24 and a perplexity of 25.72
after 20 epochs on a single TPUv3-8.
This should take less than 21 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).

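
The loss and perplexity figures are related. The sketch below assumes the usual definition, in which perplexity is the exponential of the mean per-token cross-entropy loss measured in nats; the text above does not say how the script computes it, so treat this as an illustration rather than the script's exact code.

```python
import math


def perplexity(mean_loss):
    """Perplexity for a mean per-token cross-entropy loss given in nats."""
    return math.exp(mean_loss)


def loss_from_perplexity(ppl):
    """Inverse mapping: the mean loss in nats that yields a given perplexity."""
    return math.log(ppl)


print(round(perplexity(3.24), 2))             # 25.53
print(round(loss_from_perplexity(25.72), 3))  # 3.247
```

Under this definition, a loss of 3.24 gives a perplexity of about 25.5, close to but not exactly the reported 25.72, so the two rounded figures should be read as approximate.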
## Runtime evaluation