[Flax MLM] Refactor run mlm with optax (#11745)
* refactor
* update
* update
* update
* refactor run mlm
* finalize
* refactor more
* fix typo
* update
* finish refactor
* modify run mlm
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* small fixes
* upload
* upload
* finish run mlm script

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent
43891be19b
commit
00440e350f
129
examples/flax/language-modeling/README.md
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
<!---
|
||||||
|
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Language model training examples
|
||||||
|
|
||||||
|
The following example showcases how to train a language model from scratch
|
||||||
|
using the JAX/Flax backend.
|
||||||
|
|
||||||
|
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
||||||
|
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
||||||
|
way which enables simple and efficient model parallelism.
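
To make the "immutable, purely functional update" idea concrete, here is a minimal plain-Python sketch. It is only an analogy and does not use the JAX or Flax API: the `State` class and the `apply_gradients` helper are hypothetical names chosen for illustration. The point is that an update returns a new state object instead of modifying the old one in place.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class State:
    """Hypothetical immutable training state (illustration only)."""
    step: int
    params: tuple  # immutable container of parameter values


def apply_gradients(state, grads, lr):
    """Return a NEW state after one SGD-style step; `state` is left untouched."""
    new_params = tuple(p - lr * g for p, g in zip(state.params, grads))
    return replace(state, step=state.step + 1, params=new_params)


old = State(step=0, params=(1.0, 2.0))
new = apply_gradients(old, grads=(0.5, -0.5), lr=0.1)

print(old)  # the original state still holds step=0 and the old parameters
print(new)  # the updated state holds step=1 and the new parameters
```

Because the old state is never mutated, the same update function can be applied independently on several devices, which is the property the paragraph above refers to.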
|
||||||
|
|
||||||
|
## Masked language modeling
|
||||||
|
|
||||||
|
In the following, we demonstrate how to train a bi-directional transformer model
using the masked language modeling objective introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`roberta-base`**](https://huggingface.co/roberta-base)
in Norwegian on a single TPUv3-8 pod.
|
||||||
|
|
||||||
|
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if your datasets require extra processing.
|
||||||
|
|
||||||
|
Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MODEL_DIR="./norwegian-roberta-base"
|
||||||
|
mkdir -p ${MODEL_DIR}
|
||||||
|
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Train tokenizer
|
||||||
|
|
||||||
|
In the first step, we train a tokenizer to efficiently process the text input for the model. As shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and subsequently saved in `${MODEL_DIR}`.
This can take up to 10 minutes depending on your hardware ☕.
|
||||||
|
|
||||||
|
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
|
||||||
|
|
||||||
|
### Create configuration
|
||||||
|
|
||||||
|
Next, we create the model's configuration file. This is as simple
as loading the [**`roberta-base`**](https://huggingface.co/roberta-base) configuration and storing it
in the local model folder:
|
||||||
|
|
||||||
|
```python
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained(model_dir)
```
|
||||||
|
|
||||||
|
### Train model
|
||||||
|
|
||||||
|
Next we can run the example script to pretrain the model:
|
||||||
|
|
||||||
|
```bash
./run_mlm_flax.py \
    --output_dir="./runs" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="128" \
    --per_device_eval_batch_size="128" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --pad_to_max_length \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98"
```
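
The `--learning_rate` and `--warmup_steps` flags feed the learning rate schedule that the script builds from `optax.linear_schedule` and `optax.join_schedules`: a linear warmup from 0 to the peak rate, followed by a linear decay back to 0 over the remaining training steps. The following plain-Python sketch reproduces only the shape of that schedule without depending on optax. The helper names and the total of 10000 steps are assumptions made for illustration; the real script derives the total from the dataset size, batch size, and number of epochs.

```python
def linear_schedule(init_value, end_value, transition_steps):
    """Linearly interpolate from init_value to end_value, then stay at end_value."""
    if transition_steps <= 0:
        raise ValueError("this sketch assumes transition_steps > 0")

    def schedule(step):
        frac = min(max(step, 0), transition_steps) / transition_steps
        return init_value + frac * (end_value - init_value)

    return schedule


def warmup_then_decay(peak_lr, warmup_steps, total_steps):
    """Join a warmup schedule and a decay schedule at the warmup boundary."""
    warmup = linear_schedule(0.0, peak_lr, warmup_steps)
    decay = linear_schedule(peak_lr, 0.0, total_steps - warmup_steps)

    def schedule(step):
        if step < warmup_steps:
            return warmup(step)
        # the second schedule counts steps from the boundary
        return decay(step - warmup_steps)

    return schedule


# 3e-4 and 1000 come from the command above; 10000 total steps is hypothetical.
lr_at = warmup_then_decay(peak_lr=3e-4, warmup_steps=1000, total_steps=10000)

for step in (0, 500, 1000, 5500, 10000):
    print(step, lr_at(step))
```

The learning rate rises during the first 1000 steps, peaks at the value passed via `--learning_rate`, and reaches 0 at the last training step.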
|
||||||
|
|
||||||
|
Training should converge to a loss of 1.78 and an accuracy of 0.64
after 18 epochs on a single TPUv3-8.
This should take less than 18 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).
|
||||||
|
|
||||||
|
For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this TODO: (Patrick)]() Google Colab.
|
||||||
|
|
||||||
|
|
||||||
|
## TODO(Patrick): Add comparison with PyTorch GPU/TPU
|
||||||
4
examples/flax/language-modeling/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
datasets >= 1.1.3
|
||||||
|
jax>=0.2.8
|
||||||
|
jaxlib>=0.1.59
|
||||||
|
flax>=0.3.4
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=masked-lm
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||||
@@ -35,11 +36,10 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import optax
|
||||||
from flax import jax_utils
|
from flax import jax_utils
|
||||||
from flax.optim import Adam
|
from flax.training import train_state
|
||||||
from flax.training import common_utils
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
from flax.training.common_utils import get_metrics
|
|
||||||
from jax.nn import log_softmax
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
@@ -269,167 +269,30 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
def create_learning_rate_scheduler(
|
|
||||||
factors="constant * linear_warmup * rsqrt_decay",
|
|
||||||
base_learning_rate=0.5,
|
|
||||||
warmup_steps=1000,
|
|
||||||
decay_factor=0.5,
|
|
||||||
steps_per_decay=20000,
|
|
||||||
steps_per_cycle=100000,
|
|
||||||
):
|
|
||||||
"""Creates learning rate schedule.
|
|
||||||
Interprets factors in the factors string which can consist of:
|
|
||||||
* constant: interpreted as the constant value,
|
|
||||||
* linear_warmup: interpreted as linear warmup until warmup_steps,
|
|
||||||
* rsqrt_decay: divide by square root of max(step, warmup_steps)
|
|
||||||
* rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
|
|
||||||
* decay_every: Every k steps decay the learning rate by decay_factor.
|
|
||||||
* cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
|
|
||||||
Args:
|
|
||||||
factors: string, factors separated by "*" that defines the schedule.
|
|
||||||
base_learning_rate: float, the starting constant for the lr schedule.
|
|
||||||
warmup_steps: int, how many steps to warm up for in the warmup schedule.
|
|
||||||
decay_factor: float, the amount to decay the learning rate by.
|
|
||||||
steps_per_decay: int, how often to decay the learning rate.
|
|
||||||
steps_per_cycle: int, steps per cycle when using cosine decay.
|
|
||||||
Returns:
|
|
||||||
a function learning_rate(step): float -> {"learning_rate": float}, the
|
|
||||||
step-dependent lr.
|
|
||||||
"""
|
|
||||||
factors = [n.strip() for n in factors.split("*")]
|
|
||||||
|
|
||||||
def step_fn(step):
|
|
||||||
"""Step to learning rate function."""
|
|
||||||
ret = 1.0
|
|
||||||
for name in factors:
|
|
||||||
if name == "constant":
|
|
||||||
ret *= base_learning_rate
|
|
||||||
elif name == "linear_warmup":
|
|
||||||
ret *= jnp.minimum(1.0, step / warmup_steps)
|
|
||||||
elif name == "rsqrt_decay":
|
|
||||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
|
||||||
elif name == "rsqrt_normalized_decay":
|
|
||||||
ret *= jnp.sqrt(warmup_steps)
|
|
||||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
|
||||||
elif name == "decay_every":
|
|
||||||
ret *= decay_factor ** (step // steps_per_decay)
|
|
||||||
elif name == "cosine_decay":
|
|
||||||
progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
|
|
||||||
ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown factor {name}.")
|
|
||||||
return jnp.asarray(ret, dtype=jnp.float32)
|
|
||||||
|
|
||||||
return step_fn
|
|
||||||
|
|
||||||
|
|
||||||
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
|
|
||||||
"""Compute summary metrics."""
|
|
||||||
loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
|
|
||||||
acc, _ = accuracy(logits, labels, weights)
|
|
||||||
metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer}
|
|
||||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
def accuracy(logits, targets, weights=None):
|
|
||||||
"""Compute weighted accuracy for log probs and targets.
|
|
||||||
Args:
|
|
||||||
logits: [batch, length, num_classes] float array.
|
|
||||||
targets: categorical targets [batch, length] int array.
|
|
||||||
weights: None or array of shape [batch, length]
|
|
||||||
Returns:
|
|
||||||
Tuple of scalar loss and batch normalizing factor.
|
|
||||||
"""
|
|
||||||
if logits.ndim != targets.ndim + 1:
|
|
||||||
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
|
|
||||||
|
|
||||||
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
|
|
||||||
loss *= weights
|
|
||||||
|
|
||||||
return loss.sum(), weights.sum()
|
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
|
|
||||||
"""Compute cross entropy and entropy for log probs and targets.
|
|
||||||
Args:
|
|
||||||
logits: [batch, length, num_classes] float array.
|
|
||||||
targets: categorical targets [batch, length] int array.
|
|
||||||
weights: None or array of shape [batch, length]
|
|
||||||
label_smoothing: label smoothing constant, used to determine the on and off values.
|
|
||||||
Returns:
|
|
||||||
Tuple of scalar loss and batch normalizing factor.
|
|
||||||
"""
|
|
||||||
if logits.ndim != targets.ndim + 1:
|
|
||||||
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
|
|
||||||
|
|
||||||
vocab_size = logits.shape[-1]
|
|
||||||
confidence = 1.0 - label_smoothing
|
|
||||||
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
|
||||||
normalizing_constant = -(
|
|
||||||
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
|
||||||
)
|
|
||||||
soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)
|
|
||||||
|
|
||||||
loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
|
|
||||||
loss = loss - normalizing_constant
|
|
||||||
|
|
||||||
if weights is not None:
|
|
||||||
loss = loss * weights
|
|
||||||
normalizing_factor = weights.sum()
|
|
||||||
else:
|
|
||||||
normalizing_factor = np.prod(targets.shape)
|
|
||||||
|
|
||||||
return loss.sum(), normalizing_factor
|
|
||||||
|
|
||||||
|
|
||||||
def training_step(optimizer, batch, dropout_rng):
|
|
||||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
|
||||||
|
|
||||||
def loss_fn(params):
|
|
||||||
targets = batch.pop("labels")
|
|
||||||
|
|
||||||
# Hide away tokens which doesn't participate in the optimization
|
|
||||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
|
||||||
|
|
||||||
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
|
||||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
|
||||||
return loss / weight_sum
|
|
||||||
|
|
||||||
step = optimizer.state.step
|
|
||||||
lr = lr_scheduler_fn(step)
|
|
||||||
grad_fn = jax.value_and_grad(loss_fn)
|
|
||||||
loss, grad = grad_fn(optimizer.target)
|
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
|
||||||
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
|
|
||||||
|
|
||||||
return loss, optimizer, new_dropout_rng
|
|
||||||
|
|
||||||
|
|
||||||
def eval_step(params, batch):
|
|
||||||
"""
|
|
||||||
Calculate evaluation metrics on a batch.
|
|
||||||
"""
|
|
||||||
targets = batch.pop("labels")
|
|
||||||
|
|
||||||
# Hide away tokens which doesn't participate in the optimization
|
|
||||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
|
||||||
logits = model(**batch, params=params, train=False)[0]
|
|
||||||
|
|
||||||
return compute_metrics(logits, targets, token_mask)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
||||||
nb_samples = len(samples_idx)
|
num_samples = len(samples_idx)
|
||||||
samples_to_remove = nb_samples % batch_size
|
samples_to_remove = num_samples % batch_size
|
||||||
|
|
||||||
if samples_to_remove != 0:
|
if samples_to_remove != 0:
|
||||||
samples_idx = samples_idx[:-samples_to_remove]
|
samples_idx = samples_idx[:-samples_to_remove]
|
||||||
sections_split = nb_samples // batch_size
|
sections_split = num_samples // batch_size
|
||||||
batch_idx = np.split(samples_idx, sections_split)
|
batch_idx = np.split(samples_idx, sections_split)
|
||||||
return batch_idx
|
return batch_idx
|
||||||
|
|
||||||
|
|
||||||
|
def write_metric(train_metrics, eval_metrics, train_time, step):
|
||||||
|
summary_writer.scalar("train_time", train_time, step)
|
||||||
|
|
||||||
|
train_metrics = get_metrics(train_metrics)
|
||||||
|
for key, vals in train_metrics.items():
|
||||||
|
tag = f"train_{key}"
|
||||||
|
for i, val in enumerate(vals):
|
||||||
|
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||||
|
|
||||||
|
for metric_name, value in eval_metrics.items():
|
||||||
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
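
The `generate_batch_splits` helper in the hunk above drops the trailing samples that do not fill a whole batch and then splits the remaining indices into equally sized batches. A minimal plain-Python sketch of the same idea follows. The name `split_into_batches` is hypothetical, and the sketch uses lists rather than the NumPy/JAX arrays of the script.

```python
def split_into_batches(indices, batch_size):
    """Drop the incomplete tail, then cut the rest into batches of batch_size."""
    usable = len(indices) - len(indices) % batch_size
    return [indices[i:i + batch_size] for i in range(0, usable, batch_size)]


# 10 samples with a batch size of 4: the last 2 samples are dropped.
batches = split_into_batches(list(range(10)), 4)
print(batches)
```

Equal batch sizes matter here because every batch is later sharded evenly across the available devices.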
@@ -486,6 +349,7 @@ if __name__ == "__main__":
|
|||||||
if data_args.dataset_name is not None:
|
if data_args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||||
|
|
||||||
if "validation" not in datasets.keys():
|
if "validation" not in datasets.keys():
|
||||||
datasets["validation"] = load_dataset(
|
datasets["validation"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
@@ -610,7 +474,6 @@ if __name__ == "__main__":
|
|||||||
#
|
#
|
||||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||||
|
|
||||||
tokenized_datasets = tokenized_datasets.map(
|
tokenized_datasets = tokenized_datasets.map(
|
||||||
group_texts,
|
group_texts,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -619,7 +482,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Enable tensorboard only on the master node
|
# Enable tensorboard only on the master node
|
||||||
if has_tensorboard and jax.host_id() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
@@ -632,58 +495,128 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||||
|
|
||||||
# Setup optimizer
|
|
||||||
optimizer = Adam(
|
|
||||||
learning_rate=training_args.learning_rate,
|
|
||||||
weight_decay=training_args.weight_decay,
|
|
||||||
beta1=training_args.adam_beta1,
|
|
||||||
beta2=training_args.adam_beta2,
|
|
||||||
).create(model.params)
|
|
||||||
|
|
||||||
# Create learning rate scheduler
|
|
||||||
# warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
|
|
||||||
lr_scheduler_fn = create_learning_rate_scheduler(
|
|
||||||
base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create parallel version of the training and evaluation steps
|
|
||||||
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
|
|
||||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
|
||||||
|
|
||||||
# Replicate the optimizer on each device
|
|
||||||
optimizer = jax_utils.replicate(optimizer)
|
|
||||||
|
|
||||||
# Store some constant
|
# Store some constant
|
||||||
nb_epochs = int(training_args.num_train_epochs)
|
num_epochs = int(training_args.num_train_epochs)
|
||||||
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||||
|
|
||||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||||
for epoch in epochs:
|
|
||||||
|
|
||||||
|
# Create learning rate schedule
|
||||||
|
warmup_fn = optax.linear_schedule(
|
||||||
|
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
||||||
|
)
|
||||||
|
decay_fn = optax.linear_schedule(
|
||||||
|
init_value=training_args.learning_rate,
|
||||||
|
end_value=0,
|
||||||
|
transition_steps=num_train_steps - training_args.warmup_steps,
|
||||||
|
)
|
||||||
|
linear_decay_lr_schedule_fn = optax.join_schedules(
|
||||||
|
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
||||||
|
)
|
||||||
|
|
||||||
|
# create adam optimizer
|
||||||
|
adamw = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=training_args.adam_beta1,
|
||||||
|
b2=training_args.adam_beta2,
|
||||||
|
eps=1e-8,
|
||||||
|
weight_decay=training_args.weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup train state
|
||||||
|
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
||||||
|
|
||||||
|
# Define gradient update step fn
|
||||||
|
def train_step(state, batch, dropout_rng):
|
||||||
|
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||||
|
|
||||||
|
def loss_fn(params):
|
||||||
|
labels = batch.pop("labels")
|
||||||
|
|
||||||
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
|
|
||||||
|
# compute loss, ignore padded input tokens
|
||||||
|
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||||
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||||
|
|
||||||
|
# take average
|
||||||
|
loss = loss.sum() / label_mask.sum()
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
grad_fn = jax.value_and_grad(loss_fn)
|
||||||
|
loss, grad = grad_fn(state.params)
|
||||||
|
grad = jax.lax.pmean(grad, "batch")
|
||||||
|
new_state = state.apply_gradients(grads=grad)
|
||||||
|
|
||||||
|
metrics = jax.lax.pmean(
|
||||||
|
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_state, metrics, new_dropout_rng
|
||||||
|
|
||||||
|
# Create parallel version of the train step
|
||||||
|
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
||||||
|
|
||||||
|
# Define eval fn
|
||||||
|
def eval_step(params, batch):
|
||||||
|
labels = batch.pop("labels")
|
||||||
|
|
||||||
|
logits = model(**batch, params=params, train=False)[0]
|
||||||
|
|
||||||
|
# compute loss, ignore padded input tokens
|
||||||
|
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||||
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||||
|
|
||||||
|
# compute accuracy
|
||||||
|
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
||||||
|
|
||||||
|
# summarize metrics
|
||||||
|
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
||||||
|
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||||
|
|
||||||
|
# Replicate the train state on each device
|
||||||
|
state = jax_utils.replicate(state)
|
||||||
|
|
||||||
|
train_metrics = []
|
||||||
|
train_time = 0
|
||||||
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
|
for epoch in epochs:
|
||||||
# ======================== Training ================================
|
# ======================== Training ================================
|
||||||
|
train_start = time.time()
|
||||||
|
|
||||||
# Create sampling rng
|
# Create sampling rng
|
||||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
nb_training_samples = len(tokenized_datasets["train"])
|
num_train_samples = len(tokenized_datasets["train"])
|
||||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
||||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
|
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||||
|
|
||||||
# Model forward
|
# Model forward
|
||||||
model_inputs = common_utils.shard(model_inputs.data)
|
model_inputs = shard(model_inputs.data)
|
||||||
loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
|
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||||
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
epochs.write(f"Loss: {loss}")
|
train_time += time.time() - train_start
|
||||||
|
|
||||||
|
epochs.write(
|
||||||
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||||
|
)
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
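
The new `train_step` in the hunk above computes a per-token cross entropy, zeroes out every position whose label is not positive via `jnp.where(labels > 0, 1.0, 0.0)`, and divides the summed loss by the number of remaining positions. The plain-Python sketch below shows the same masked average for a single sequence. The function names are hypothetical, and the sketch assumes that at least one label is positive, since otherwise the division has no valid denominator.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_mean_cross_entropy(logits, labels):
    """Average the cross entropy over positions whose label is > 0."""
    total, count = 0.0, 0
    for token_logits, label in zip(logits, labels):
        if label > 0:  # same condition as the label mask in the script
            total += -log_softmax(token_logits)[label]
            count += 1
    return total / count


# Two positions over a two-token vocabulary; the second position is ignored
# because its label is not positive.
loss = masked_mean_cross_entropy([[0.0, 0.0], [5.0, -5.0]], [1, -100])
print(loss)
```

With uniform logits at the only counted position, the loss equals the natural logarithm of the vocabulary size, which is a quick sanity check for this kind of code.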
@@ -692,26 +625,27 @@ if __name__ == "__main__":
|
|||||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||||
|
|
||||||
# Model forward
|
# Model forward
|
||||||
model_inputs = common_utils.shard(model_inputs.data)
|
model_inputs = shard(model_inputs.data)
|
||||||
metrics = p_eval_step(optimizer.target, model_inputs)
|
metrics = p_eval_step(state.params, model_inputs)
|
||||||
eval_metrics.append(metrics)
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
eval_metrics_np = get_metrics(eval_metrics)
|
# normalize eval metrics
|
||||||
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
eval_normalizer = eval_metrics_np.pop("normalizer")
|
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
||||||
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
eval_normalizer = eval_metrics.pop("normalizer")
|
||||||
|
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
epochs.desc = (
|
epochs.desc = (
|
||||||
f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save metrics
|
# Save metrics
|
||||||
if has_tensorboard and jax.host_id() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
for name, value in eval_summary.items():
|
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||||
summary_writer.scalar(name, value, epoch)
|
write_metric(train_metrics, eval_metrics, train_time, cur_step)
|
||||||
|
|
||||||
# save last checkpoint
|
# save last checkpoint
|
||||||
if jax.host_id() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(training_args.output_dir, params=params)
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
|
|||||||
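
In the evaluation loop above, each batch reports summed values (`loss`, `accuracy`, `normalizer`), and the final metrics are obtained by summing over all batches and dividing by the total normalizer. The plain-Python sketch below illustrates that reduction with made-up numbers. The function name and the sample values are assumptions for illustration, not output of the script.

```python
def normalize_eval_metrics(per_batch):
    """Sum every metric across batches, then divide by the summed normalizer."""
    totals = {}
    for batch_metrics in per_batch:
        for name, value in batch_metrics.items():
            totals[name] = totals.get(name, 0.0) + value
    normalizer = totals.pop("normalizer")
    return {name: value / normalizer for name, value in totals.items()}


# Hypothetical per-batch sums: 5 counted tokens in the first batch, 3 in the second.
summary = normalize_eval_metrics([
    {"loss": 4.0, "accuracy": 3.0, "normalizer": 5.0},
    {"loss": 2.0, "accuracy": 1.0, "normalizer": 3.0},
])
print(summary)
```

Dividing once at the end weights every counted token equally, whereas averaging per-batch means would over-weight batches that contain few counted tokens.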
@@ -1,5 +1,5 @@
|
|||||||
datasets >= 1.1.3
|
datasets >= 1.1.3
|
||||||
jax>=0.2.8
|
jax>=0.2.8
|
||||||
jaxlib>=0.1.59
|
jaxlib>=0.1.59
|
||||||
git+https://github.com/google/flax.git
|
flax>=0.3.4
|
||||||
git+https://github.com/deepmind/optax.git
|
git+https://github.com/deepmind/optax.git
|
||||||
|
|||||||