From d36fce8237eab3af6d717da8530d6edafd045e1e Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 14 Jun 2021 15:03:55 +0530 Subject: [PATCH] add readme for flax clm (#12111) * add readme for flax clm * use section link for tokenizer * Apply suggestions from code review Co-authored-by: Patrick von Platen * update metrics Co-authored-by: Patrick von Platen --- examples/flax/language-modeling/README.md | 62 +++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index a9fa0df1f8..34d5cae140 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -125,6 +125,68 @@ Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experi For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/masked_language_modeling_flax.ipynb) google colab. +## Causal language modeling + +In the following, we demonstrate how to train an auto-regressive causal transformer model +in JAX/Flax. +More specifically, we pretrain a randomely initialized [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8. +to pre-train 124M [**`gpt2`**](https://huggingface.co/gpt2) +in Norwegian on a single TPUv3-8 pod. + +The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets. + +Let's start by creating a folder to save the trained model and a symbolic link to the `run_clm_flax.py` script. + +```bash +export MODEL_DIR="./norwegian-gpt2" +mkdir -p ${MODEL_DIR} +ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py +``` + +Next, we'll follow the same steps as above in [Train tokenizer](#train-tokenizer) to train the tokenizer. + +### Create configuration + +Next, we create the model's configuration file. This is as simple +as loading and storing [`**gpt2**`](https://huggingface.co/gpt2) +in the local model folder: + +```python +from transformers import GPT2Config + +model_dir = "./norwegian-gpt2" # ${MODEL_DIR} + +config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0) +config.save_pretrained(model_dir) +``` + +### Train model + +Next we can run the example script to pretrain the model: + +```bash +./run_clm_flax.py \ + --output_dir="./runs" \ + --model_type="gpt2" \ + --config_name="${MODEL_DIR}" \ + --tokenizer_name="${MODEL_DIR}" \ + --dataset_name="oscar" \ + --dataset_config_name="unshuffled_deduplicated_no" \ + --do_train --do_eval \ + --block_size="512" \ + --per_device_train_batch_size="64" \ + --per_device_eval_batch_size="64" \ + --learning_rate="5e-3" --warmup_steps="1000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="20" \ +``` + +Training should converge at a loss and perplexity +of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8. +This should take less than ~21 hours. +Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA). + ## Runtime evaluation