[Flax MLM] Refactor run mlm with optax (#11745)
* refactor
* update
* update
* update
* refactor run mlm
* finalize
* refactor more
* fix typo
* update
* finish refactor
* modify run mlm
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* small fixes
* upload
* upload
* finish run mlm script

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 43891be19b
commit 00440e350f
129
examples/flax/language-modeling/README.md
Normal file
@@ -0,0 +1,129 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Language model training examples

The following example showcases how to train a language model from scratch
using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.
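
As a rough illustration of this functional style, the sketch below compiles a pure update function with `jax.jit` and returns new parameters instead of modifying the old ones. The toy linear model, the `sgd_step` name, and the 0.1 step size are made up for this example and are not part of the training script.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # A pure function: the result depends only on the arguments.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # traced on first call, then compiled by XLA for the available backend
def sgd_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Parameters are never modified in place; a new pytree is returned.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

new_params = sgd_step(params, x, y)
```

After the call, `params` still holds the original values, while `new_params` holds the updated ones.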

## Masked language modeling

In the following, we demonstrate how to train a bi-directional transformer model
using the masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`roberta-base`**](https://huggingface.co/roberta-base)
in Norwegian on a single TPUv3-8 pod.

The example script uses the 🤗 Datasets library. You can easily customize it to your needs if your datasets need extra processing.

Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.

```bash
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```

### Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and subsequently saved in `${MODEL_DIR}`.
This can take up to 10 minutes depending on your hardware ☕.

```python
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
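
The `batch_iterator` generator feeds the trainer slices of 1,000 texts at a time, so the corpus never has to be materialized as a single list. A minimal sketch of the slicing behaviour, using a plain Python list as a hypothetical stand-in for the `"text"` column of the dataset:

```python
def batch_iterator(texts, batch_size=1000):
    # Yield consecutive slices; the last one may be shorter than batch_size.
    for i in range(0, len(texts), batch_size):
        yield texts[i : i + batch_size]


# Hypothetical stand-in for the "text" column of the dataset.
toy_texts = [f"setning {n}" for n in range(2500)]
batch_sizes = [len(batch) for batch in batch_iterator(toy_texts)]
```

With 2,500 texts this yields two full batches and one final batch of 500, so no text is dropped.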

### Create configuration

Next, we create the model's configuration file. This is as simple
as loading and storing [**`roberta-base`**](https://huggingface.co/roberta-base)
in the local model folder:

```python
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained(model_dir)
```

### Train model

Next, we can run the example script to pre-train the model:

```bash
./run_mlm_flax.py \
    --output_dir="./runs" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="128" \
    --per_device_eval_batch_size="128" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --pad_to_max_length \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98"
```

Training should converge at a loss and accuracy
of 1.78 and 0.64 respectively after 18 epochs on a single TPUv3-8.
This should take less than 18 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).
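
To relate the flags above to the number of optimizer steps, the following back-of-the-envelope sketch mirrors how the script derives its global batch size and total step count. The sample count is a made-up placeholder, not the size of the Norwegian OSCAR split.

```python
per_device_train_batch_size = 128  # --per_device_train_batch_size
num_devices = 8  # a TPUv3-8 exposes 8 cores; the script uses jax.device_count()
num_epochs = 18  # --num_train_epochs

# Hypothetical number of grouped training sequences, for illustration only.
num_train_samples = 1_000_000

train_batch_size = per_device_train_batch_size * num_devices
steps_per_epoch = num_train_samples // train_batch_size  # remainder is dropped
num_train_steps = steps_per_epoch * num_epochs
```

Under these assumptions each step processes 1,024 sequences, and the 1,000 warmup steps cover roughly the first epoch.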

For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this TODO: (Patrick)]() Google Colab.


## TODO(Patrick): Add comparison with PyTorch GPU/TPU
4
examples/flax/language-modeling/requirements.txt
Normal file
@@ -0,0 +1,4 @@
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # coding=utf-8
-# Copyright 2020 The HuggingFace Team All rights reserved.
+# Copyright 2021 The HuggingFace Team All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=masked-lm
 import logging
 import os
 import sys
+import time
 from dataclasses import dataclass, field
 
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
@@ -35,11 +36,10 @@ from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
+import optax
 from flax import jax_utils
-from flax.optim import Adam
-from flax.training import common_utils
-from flax.training.common_utils import get_metrics
-from jax.nn import log_softmax
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
     CONFIG_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
@@ -269,167 +269,30 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def create_learning_rate_scheduler(
|
||||
factors="constant * linear_warmup * rsqrt_decay",
|
||||
base_learning_rate=0.5,
|
||||
warmup_steps=1000,
|
||||
decay_factor=0.5,
|
||||
steps_per_decay=20000,
|
||||
steps_per_cycle=100000,
|
||||
):
|
||||
"""Creates learning rate schedule.
|
||||
Interprets factors in the factors string which can consist of:
|
||||
* constant: interpreted as the constant value,
|
||||
* linear_warmup: interpreted as linear warmup until warmup_steps,
|
||||
* rsqrt_decay: divide by square root of max(step, warmup_steps)
|
||||
* rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
|
||||
* decay_every: Every k steps decay the learning rate by decay_factor.
|
||||
* cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
|
||||
Args:
|
||||
factors: string, factors separated by "*" that defines the schedule.
|
||||
base_learning_rate: float, the starting constant for the lr schedule.
|
||||
warmup_steps: int, how many steps to warm up for in the warmup schedule.
|
||||
decay_factor: float, the amount to decay the learning rate by.
|
||||
steps_per_decay: int, how often to decay the learning rate.
|
||||
steps_per_cycle: int, steps per cycle when using cosine decay.
|
||||
Returns:
|
||||
a function learning_rate(step): float -> {"learning_rate": float}, the
|
||||
step-dependent lr.
|
||||
"""
|
||||
factors = [n.strip() for n in factors.split("*")]
|
||||
|
||||
def step_fn(step):
|
||||
"""Step to learning rate function."""
|
||||
ret = 1.0
|
||||
for name in factors:
|
||||
if name == "constant":
|
||||
ret *= base_learning_rate
|
||||
elif name == "linear_warmup":
|
||||
ret *= jnp.minimum(1.0, step / warmup_steps)
|
||||
elif name == "rsqrt_decay":
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "rsqrt_normalized_decay":
|
||||
ret *= jnp.sqrt(warmup_steps)
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "decay_every":
|
||||
ret *= decay_factor ** (step // steps_per_decay)
|
||||
elif name == "cosine_decay":
|
||||
progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
|
||||
ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
|
||||
else:
|
||||
raise ValueError(f"Unknown factor {name}.")
|
||||
return jnp.asarray(ret, dtype=jnp.float32)
|
||||
|
||||
return step_fn
|
||||
|
||||
|
||||
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
|
||||
"""Compute summary metrics."""
|
||||
loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
|
||||
acc, _ = accuracy(logits, labels, weights)
|
||||
metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer}
|
||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
|
||||
def accuracy(logits, targets, weights=None):
|
||||
"""Compute weighted accuracy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
|
||||
|
||||
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
|
||||
loss *= weights
|
||||
|
||||
return loss.sum(), weights.sum()
|
||||
|
||||
|
||||
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
|
||||
"""Compute cross entropy and entropy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
label_smoothing: label smoothing constant, used to determine the on and off values.
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
confidence = 1.0 - label_smoothing
|
||||
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||
normalizing_constant = -(
|
||||
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||
)
|
||||
soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||
|
||||
loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
|
||||
loss = loss - normalizing_constant
|
||||
|
||||
if weights is not None:
|
||||
loss = loss * weights
|
||||
normalizing_factor = weights.sum()
|
||||
else:
|
||||
normalizing_factor = np.prod(targets.shape)
|
||||
|
||||
return loss.sum(), normalizing_factor
|
||||
|
||||
|
||||
def training_step(optimizer, batch, dropout_rng):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
|
||||
def loss_fn(params):
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
|
||||
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
||||
return loss / weight_sum
|
||||
|
||||
step = optimizer.state.step
|
||||
lr = lr_scheduler_fn(step)
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(optimizer.target)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
|
||||
|
||||
return loss, optimizer, new_dropout_rng
|
||||
|
||||
|
||||
def eval_step(params, batch):
|
||||
"""
|
||||
Calculate evaluation metrics on a batch.
|
||||
"""
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
return compute_metrics(logits, targets, token_mask)
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
||||
nb_samples = len(samples_idx)
|
||||
samples_to_remove = nb_samples % batch_size
|
||||
num_samples = len(samples_idx)
|
||||
samples_to_remove = num_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = nb_samples // batch_size
|
||||
sections_split = num_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
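
Review note: `generate_batch_splits` drops the trailing remainder so that every batch has exactly `batch_size` indices, which keeps array shapes constant for the compiled step functions. A small NumPy sketch of the same logic, using the renamed variables from this hunk and a toy index range as an assumption:

```python
import numpy as np


def generate_batch_splits(samples_idx, batch_size):
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size

    # Drop the trailing remainder so every batch has exactly batch_size items.
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    return np.split(samples_idx, sections_split)


# Toy input: 10 indices with a batch size of 4 leave a remainder of 2.
batches = generate_batch_splits(np.arange(10), batch_size=4)
```

Here indices 8 and 9 are discarded and two batches of four remain.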
|
||||
|
||||
|
||||
def write_metric(train_metrics, eval_metrics, train_time, step):
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
train_metrics = get_metrics(train_metrics)
|
||||
for key, vals in train_metrics.items():
|
||||
tag = f"train_{key}"
|
||||
for i, val in enumerate(vals):
|
||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||
|
||||
for metric_name, value in eval_metrics.items():
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
@@ -486,6 +349,7 @@ if __name__ == "__main__":
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||
|
||||
if "validation" not in datasets.keys():
|
||||
datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
@@ -610,7 +474,6 @@ if __name__ == "__main__":
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
tokenized_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
@@ -619,7 +482,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
||||
|
||||
# Data collator
|
||||
@@ -632,58 +495,128 @@ if __name__ == "__main__":
|
||||
|
||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
weight_decay=training_args.weight_decay,
|
||||
beta1=training_args.adam_beta1,
|
||||
beta2=training_args.adam_beta2,
|
||||
).create(model.params)
|
||||
|
||||
# Create learning rate scheduler
|
||||
# warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
|
||||
lr_scheduler_fn = create_learning_rate_scheduler(
|
||||
base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
|
||||
)
|
||||
|
||||
# Create parallel version of the training and evaluation steps
|
||||
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
|
||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the optimizer on each device
|
||||
optimizer = jax_utils.replicate(optimizer)
|
||||
|
||||
# Store some constant
|
||||
nb_epochs = int(training_args.num_train_epochs)
|
||||
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||
|
||||
# Create learning rate schedule
|
||||
warmup_fn = optax.linear_schedule(
|
||||
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
||||
)
|
||||
decay_fn = optax.linear_schedule(
|
||||
init_value=training_args.learning_rate,
|
||||
end_value=0,
|
||||
transition_steps=num_train_steps - training_args.warmup_steps,
|
||||
)
|
||||
linear_decay_lr_schedule_fn = optax.join_schedules(
|
||||
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
||||
)
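
Review note: the three optax calls above build a linear warmup followed by a linear decay to zero. The pure-Python sketch below re-implements that shape for illustration only; it is not optax code, it assumes positive `transition_steps`, and the total of 10,000 steps is a hypothetical value.

```python
def linear_schedule(init_value, end_value, transition_steps):
    # Linearly interpolate from init_value to end_value, then hold end_value.
    def schedule(step):
        frac = min(max(step / transition_steps, 0.0), 1.0)
        return init_value + (end_value - init_value) * frac

    return schedule


def join_two(first, second, boundary):
    # Use `first` before the boundary; afterwards call `second` with the
    # step counted from the boundary.
    def schedule(step):
        return first(step) if step < boundary else second(step - boundary)

    return schedule


learning_rate = 3e-4  # --learning_rate
warmup_steps = 1000  # --warmup_steps
num_train_steps = 10_000  # hypothetical total number of steps

warmup_fn = linear_schedule(0.0, learning_rate, warmup_steps)
decay_fn = linear_schedule(learning_rate, 0.0, num_train_steps - warmup_steps)
lr_schedule = join_two(warmup_fn, decay_fn, warmup_steps)
```

The rate starts at 0, peaks at `learning_rate` exactly at step `warmup_steps`, and reaches 0 again at `num_train_steps`.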
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
b1=training_args.adam_beta1,
|
||||
b2=training_args.adam_beta2,
|
||||
eps=1e-8,
|
||||
weight_decay=training_args.weight_decay,
|
||||
)
|
||||
|
||||
# Setup train state
|
||||
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, dropout_rng):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
|
||||
def loss_fn(params):
|
||||
labels = batch.pop("labels")
|
||||
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
|
||||
# compute loss, ignore padded input tokens
|
||||
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# take average
|
||||
loss = loss.sum() / label_mask.sum()
|
||||
|
||||
return loss
|
||||
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
new_state = state.apply_gradients(grads=grad)
|
||||
|
||||
metrics = jax.lax.pmean(
|
||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
||||
)
|
||||
|
||||
return new_state, metrics, new_dropout_rng
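
Review note: the loss in `train_step` is a per-token cross entropy averaged over the positions where `labels > 0`. The NumPy sketch below reproduces that computation outside of JAX. The function name `masked_lm_loss` is invented here, and `-100` as the label for ignored positions is an assumption about the data collator. Note that the `labels > 0` test also excludes positions whose label id is 0.

```python
import numpy as np


def masked_lm_loss(logits, labels):
    # Positions with label <= 0 are excluded, mirroring `labels > 0` above.
    label_mask = (labels > 0).astype(np.float64)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Negative log-likelihood of the target token at each position.
    safe_labels = np.where(labels > 0, labels, 0)  # any valid index for masked slots
    nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]

    # Average only over the positions that take part in the loss.
    return (nll * label_mask).sum() / label_mask.sum()


# Toy batch: 1 sequence, 3 positions, vocabulary of 4 tokens.
logits = np.zeros((1, 3, 4))
labels = np.array([[-100, 2, 3]])  # -100 marks an ignored position (assumed)
loss = masked_lm_loss(logits, labels)
```

With uniform logits over four tokens, the loss equals `log(4)` regardless of how many positions are ignored.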
|
||||
|
||||
# Create parallel version of the train step
|
||||
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
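
Review note: `jax.pmap` maps over the leading axis of its inputs, which is why the training loop passes each batch through `shard` before calling `p_train_step`. The NumPy sketch below shows the reshape that such a helper is understood to perform. `shard_like` is a hypothetical name, and the shapes assume 8 devices with 128 sequences of length 128 each.

```python
import numpy as np


def shard_like(batch, num_devices):
    # Reshape every array from (global_batch, ...) to
    # (num_devices, global_batch // num_devices, ...).
    return {
        name: array.reshape((num_devices, -1) + array.shape[1:])
        for name, array in batch.items()
    }


# Hypothetical global batch: 8 devices x 128 sequences of length 128.
batch = {
    "input_ids": np.zeros((1024, 128), dtype=np.int32),
    "labels": np.zeros((1024, 128), dtype=np.int32),
}
sharded = shard_like(batch, num_devices=8)
```

Each device then receives one slice along the first axis, here an array of shape `(128, 128)`.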
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch):
|
||||
labels = batch.pop("labels")
|
||||
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
# compute loss, ignore padded input tokens
|
||||
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# compute accuracy
|
||||
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
||||
|
||||
# summarize metrics
|
||||
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||
|
||||
return metrics
|
||||
|
||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = jax_utils.replicate(state)
|
||||
|
||||
train_metrics = []
|
||||
train_time = 0
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
train_start = time.time()
|
||||
|
||||
# Create sampling rng
|
||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
|
||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||
nb_training_samples = len(tokenized_datasets["train"])
|
||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
||||
num_train_samples = len(tokenized_datasets["train"])
|
||||
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
|
||||
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
|
||||
model_inputs = shard(model_inputs.data)
|
||||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
epochs.write(f"Loss: {loss}")
|
||||
train_time += time.time() - train_start
|
||||
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
@@ -692,26 +625,27 @@ if __name__ == "__main__":
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
metrics = p_eval_step(optimizer.target, model_inputs)
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_metrics_np = get_metrics(eval_metrics)
|
||||
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
|
||||
eval_normalizer = eval_metrics_np.pop("normalizer")
|
||||
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
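
Review note: `eval_step` returns sums rather than means, so the normalization above divides the totals by the total number of scored tokens. A pure-Python sketch with made-up per-batch values:

```python
# Per-batch sums, as a stand-in for what eval_step returns after psum.
eval_batches = [
    {"loss": 30.0, "accuracy": 12.0, "normalizer": 20.0},
    {"loss": 10.0, "accuracy": 6.0, "normalizer": 10.0},
]

# Sum every metric over all batches.
totals = {key: sum(batch[key] for batch in eval_batches) for key in eval_batches[0]}

# Divide by the total number of scored tokens to get per-token averages.
normalizer = totals.pop("normalizer")
eval_summary = {key: value / normalizer for key, value in totals.items()}
```

Averaging the per-batch means instead (1.5 and 1.0 for the loss here) would give 1.25 rather than 40/30, because batches with fewer masked tokens would be over-weighted.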
|
||||
|
||||
# Update progress bar
|
||||
epochs.desc = (
|
||||
f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||
)
|
||||
|
||||
# Save metrics
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
for name, value in eval_summary.items():
|
||||
summary_writer.scalar(name, value, epoch)
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||
write_metric(train_metrics, eval_metrics, train_time, cur_step)
|
||||
|
||||
# save last checkpoint
|
||||
if jax.host_id() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
# save last checkpoint
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
|
||||
@@ -1,5 +1,5 @@
 datasets >= 1.1.3
 jax>=0.2.8
 jaxlib>=0.1.59
-git+https://github.com/google/flax.git
+flax>=0.3.4
 git+https://github.com/deepmind/optax.git