[Flax MLM] Refactor run mlm with optax (#11745)
* refactor * update * update * update * refactor run mlm * finalize * refactor more * fix typo * update * finish refactor * modify run mlm * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * small fixes * upload * upload * finish run mlm script Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
43891be19b
commit
00440e350f
129
examples/flax/language-modeling/README.md
Normal file
129
examples/flax/language-modeling/README.md
Normal file
@@ -0,0 +1,129 @@
|
||||
<!---
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Language model training examples
|
||||
|
||||
The following example showcases how to train a language model from scratch
|
||||
using the JAX/Flax backend.
|
||||
|
||||
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
||||
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
||||
way which enables simple and efficient model parallelism.
|
||||
|
||||
## Masked language modeling
|
||||
|
||||
In the following, we demonstrate how to train a bi-directional transformer model
|
||||
using masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
|
||||
More specifically, we demonstrate how JAX/Flax can be leveraged
|
||||
to pre-train [**`roberta-base`**](https://huggingface.co/roberta-base)
|
||||
in Norwegian on a single TPUv3-8 pod.
|
||||
|
||||
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
||||
|
||||
Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="./norwegian-roberta-base"
|
||||
mkdir -p ${MODEL_DIR}
|
||||
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
|
||||
```
|
||||
|
||||
### Train tokenizer
|
||||
|
||||
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
|
||||
The tokenizer is trained on the complete Norwegian dataset of OSCAR
|
||||
and consequently saved in `${MODEL_DIR}`
|
||||
This can take up to 10 minutes depending on your hardware ☕.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
|
||||
|
||||
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
|
||||
|
||||
# load dataset
|
||||
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
|
||||
|
||||
# Instantiate tokenizer
|
||||
tokenizer = ByteLevelBPETokenizer()
|
||||
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i: i + batch_size]["text"]
|
||||
|
||||
# Customized training
|
||||
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
|
||||
"<s>",
|
||||
"<pad>",
|
||||
"</s>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
])
|
||||
|
||||
# Save files to disk
|
||||
tokenizer.save(f"{model_dir}/tokenizer.json")
|
||||
```
|
||||
|
||||
### Create configuration
|
||||
|
||||
Next, we create the model's configuration file. This is as simple
|
||||
as loading and storing [`**roberta-base**`](https://huggingface.co/roberta-base)
|
||||
in the local model folder:
|
||||
|
||||
```python
|
||||
from transformers import RobertaConfig
|
||||
|
||||
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
|
||||
|
||||
config = RobertaConfig.from_pretrained("roberta-base")
|
||||
config.save_pretrained(model_dir)
|
||||
```
|
||||
|
||||
### Train model
|
||||
|
||||
Next we can run the example script to pretrain the model:
|
||||
|
||||
```bash
|
||||
./run_mlm_flax.py \
|
||||
--output_dir="./runs" \
|
||||
--model_type="roberta" \
|
||||
--config_name="${MODEL_DIR}" \
|
||||
--tokenizer_name="${MODEL_DIR}" \
|
||||
--dataset_name="oscar" \
|
||||
--dataset_config_name="unshuffled_deduplicated_no" \
|
||||
--max_seq_length="128" \
|
||||
--weight_decay="0.01" \
|
||||
--per_device_train_batch_size="128" \
|
||||
--per_device_eval_batch_size="128" \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps="1000" \
|
||||
--overwrite_output_dir \
|
||||
--pad_to_max_length \
|
||||
--num_train_epochs="18" \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.98"
|
||||
```
|
||||
|
||||
Training should converge at a loss and accuracy
|
||||
of 1.78 and 0.64 respectively after 18 epochs on a single TPUv3-8.
|
||||
This should take less than 18 hours.
|
||||
Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).
|
||||
|
||||
For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
|
||||
look at [this TODO: (Patrick)]() google colab.
|
||||
|
||||
|
||||
## TODO(Patrick): Add comparison with PyTorch GPU/TPU
|
||||
Reference in New Issue
Block a user