From 00440e350f58e33435f823ec8a940bd3861fe7ba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 May 2021 12:00:58 +0100 Subject: [PATCH] [Flax MLM] Refactor run mlm with optax (#11745) * refactor * update * update * update * refactor run mlm * finalize * refactor more * fix typo * update * finish refactor * modify run mlm * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * small fixes * upload * upload * finish run mlm script Co-authored-by: Patrick von Platen --- examples/flax/language-modeling/README.md | 129 +++++++ .../flax/language-modeling/requirements.txt | 4 + .../flax/language-modeling/run_mlm_flax.py | 352 +++++++----------- .../flax/text-classification/requirements.txt | 2 +- 4 files changed, 277 insertions(+), 210 deletions(-) create mode 100644 examples/flax/language-modeling/README.md create mode 100644 examples/flax/language-modeling/requirements.txt diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md new file mode 100644 index 0000000000..9c3510ca98 --- /dev/null +++ b/examples/flax/language-modeling/README.md @@ -0,0 +1,129 @@ + + +# Language model training examples + +The following example showcases how to train a language model from scratch +using the JAX/Flax backend. + +JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. +Models written in JAX/Flax are **immutable** and updated in a purely functional +way which enables simple and efficient model parallelism. + +## Masked language modeling + +In the following, we demonstrate how to train a bi-directional transformer model +using the masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). 
+More specifically, we demonstrate how JAX/Flax can be leveraged +to pre-train [**`roberta-base`**](https://huggingface.co/roberta-base) +in Norwegian on a single TPUv3-8 pod. + +The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets. + +Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script. + +```bash +export MODEL_DIR="./norwegian-roberta-base" +mkdir -p ${MODEL_DIR} +ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py +``` + +### Train tokenizer + +In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**. +The tokenizer is trained on the complete Norwegian dataset of OSCAR +and subsequently saved in `${MODEL_DIR}`. +This can take up to 10 minutes depending on your hardware ☕. + +```python +from datasets import load_dataset +from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer + +model_dir = "./norwegian-roberta-base" # ${MODEL_DIR} + +# load dataset +dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train") + +# Instantiate tokenizer +tokenizer = ByteLevelBPETokenizer() + +def batch_iterator(batch_size=1000): + for i in range(0, len(dataset), batch_size): + yield dataset[i: i + batch_size]["text"] + +# Customized training +tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[ + "<s>", + "<pad>", + "</s>", + "<unk>", + "<mask>", +]) + +# Save files to disk +tokenizer.save(f"{model_dir}/tokenizer.json") +``` + +### Create configuration + +Next, we create the model's configuration file. 
This is as simple +as loading and storing [**`roberta-base`**](https://huggingface.co/roberta-base) +in the local model folder: + +```python +from transformers import RobertaConfig + +model_dir = "./norwegian-roberta-base" # ${MODEL_DIR} + +config = RobertaConfig.from_pretrained("roberta-base") +config.save_pretrained(model_dir) +``` + +### Train model + +Next, we can run the example script to pretrain the model: + +```bash +./run_mlm_flax.py \ + --output_dir="./runs" \ + --model_type="roberta" \ + --config_name="${MODEL_DIR}" \ + --tokenizer_name="${MODEL_DIR}" \ + --dataset_name="oscar" \ + --dataset_config_name="unshuffled_deduplicated_no" \ + --max_seq_length="128" \ + --weight_decay="0.01" \ + --per_device_train_batch_size="128" \ + --per_device_eval_batch_size="128" \ + --learning_rate="3e-4" \ + --warmup_steps="1000" \ + --overwrite_output_dir \ + --pad_to_max_length \ + --num_train_epochs="18" \ + --adam_beta1="0.9" \ + --adam_beta2="0.98" +``` + +Training should converge at a loss and accuracy +of 1.78 and 0.64 respectively after 18 epochs on a single TPUv3-8. +This should take less than 18 hours. +Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg). + +For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a +look at [this TODO: (Patrick)]() Google Colab. 
+ + +## TODO(Patrick): Add comparison with PyTorch GPU/TPU diff --git a/examples/flax/language-modeling/requirements.txt b/examples/flax/language-modeling/requirements.txt new file mode 100644 index 0000000000..7d4d161529 --- /dev/null +++ b/examples/flax/language-modeling/requirements.txt @@ -0,0 +1,4 @@ +datasets >= 1.1.3 +jax>=0.2.8 +jaxlib>=0.1.59 +flax>=0.3.4 diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 37fb7b585b..09885524d2 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2020 The HuggingFace Team All rights reserved. +# Copyright 2021 The HuggingFace Team All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ https://huggingface.co/models?filter=masked-lm import logging import os import sys +import time from dataclasses import dataclass, field # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. @@ -35,11 +36,10 @@ from tqdm import tqdm import jax import jax.numpy as jnp +import optax from flax import jax_utils -from flax.optim import Adam -from flax.training import common_utils -from flax.training.common_utils import get_metrics -from jax.nn import log_softmax +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard from transformers import ( CONFIG_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, @@ -269,167 +269,30 @@ class FlaxDataCollatorForLanguageModeling: return inputs, labels -def create_learning_rate_scheduler( - factors="constant * linear_warmup * rsqrt_decay", - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000, -): - """Creates learning rate schedule. 
- Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split("*")] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == "constant": - ret *= base_learning_rate - elif name == "linear_warmup": - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == "rsqrt_decay": - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == "rsqrt_normalized_decay": - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == "decay_every": - ret *= decay_factor ** (step // steps_per_decay) - elif name == "cosine_decay": - progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle)) - ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) - else: - raise ValueError(f"Unknown factor {name}.") - return jnp.asarray(ret, dtype=jnp.float32) - - return step_fn - - -def compute_metrics(logits, labels, weights, label_smoothing=0.0): - """Compute summary metrics.""" - loss, normalizer = 
cross_entropy(logits, labels, weights, label_smoothing) - acc, _ = accuracy(logits, labels, weights) - metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer} - metrics = jax.lax.psum(metrics, axis_name="batch") - return metrics - - -def accuracy(logits, targets, weights=None): - """Compute weighted accuracy for log probs and targets. - Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array. - weights: None or array of shape [batch, length] - Returns: - Tuple of scalar loss and batch normalizing factor. - """ - if logits.ndim != targets.ndim + 1: - raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets") - - loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) - loss *= weights - - return loss.sum(), weights.sum() - - -def cross_entropy(logits, targets, weights=None, label_smoothing=0.0): - """Compute cross entropy and entropy for log probs and targets. - Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array. - weights: None or array of shape [batch, length] - label_smoothing: label smoothing constant, used to determine the on and off values. - Returns: - Tuple of scalar loss and batch normalizing factor. - """ - if logits.ndim != targets.ndim + 1: - raise ValueError(f"Incorrect shapes. 
Got shape {logits.shape} logits and {targets.shape} targets") - - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence) - - loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1) - loss = loss - normalizing_constant - - if weights is not None: - loss = loss * weights - normalizing_factor = weights.sum() - else: - normalizing_factor = np.prod(targets.shape) - - return loss.sum(), normalizing_factor - - -def training_step(optimizer, batch, dropout_rng): - dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) - - def loss_fn(params): - targets = batch.pop("labels") - - # Hide away tokens which doesn't participate in the optimization - token_mask = jnp.where(targets > 0, 1.0, 0.0) - - logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] - loss, weight_sum = cross_entropy(logits, targets, token_mask) - return loss / weight_sum - - step = optimizer.state.step - lr = lr_scheduler_fn(step) - grad_fn = jax.value_and_grad(loss_fn) - loss, grad = grad_fn(optimizer.target) - grad = jax.lax.pmean(grad, "batch") - optimizer = optimizer.apply_gradient(grad, learning_rate=lr) - - return loss, optimizer, new_dropout_rng - - -def eval_step(params, batch): - """ - Calculate evaluation metrics on a batch. 
- """ - targets = batch.pop("labels") - - # Hide away tokens which doesn't participate in the optimization - token_mask = jnp.where(targets > 0, 1.0, 0.0) - logits = model(**batch, params=params, train=False)[0] - - return compute_metrics(logits, targets, token_mask) - - def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: - nb_samples = len(samples_idx) - samples_to_remove = nb_samples % batch_size + num_samples = len(samples_idx) + samples_to_remove = num_samples % batch_size if samples_to_remove != 0: samples_idx = samples_idx[:-samples_to_remove] - sections_split = nb_samples // batch_size + sections_split = num_samples // batch_size batch_idx = np.split(samples_idx, sections_split) return batch_idx +def write_metric(train_metrics, eval_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + if __name__ == "__main__": # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -486,6 +349,7 @@ if __name__ == "__main__": if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + if "validation" not in datasets.keys(): datasets["validation"] = load_dataset( data_args.dataset_name, @@ -610,7 +474,6 @@ if __name__ == "__main__": # # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - tokenized_datasets = tokenized_datasets.map( group_texts, batched=True, @@ -619,7 +482,7 @@ if __name__ == "__main__": ) # Enable tensorboard only on the master node - if has_tensorboard and jax.host_id() == 0: + if has_tensorboard and jax.process_index() == 0: summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()) # Data collator @@ -632,58 +495,128 @@ if __name__ == "__main__": model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) - # Setup optimizer - optimizer = Adam( - learning_rate=training_args.learning_rate, - weight_decay=training_args.weight_decay, - beta1=training_args.adam_beta1, - beta2=training_args.adam_beta2, - ).create(model.params) - - # Create learning rate scheduler - # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent. - lr_scheduler_fn = create_learning_rate_scheduler( - base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1) - ) - - # Create parallel version of the training and evaluation steps - p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,)) - p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) - - # Replicate the optimizer on each device - optimizer = jax_utils.replicate(optimizer) - # Store some constant - nb_epochs = int(training_args.num_train_epochs) - batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + num_epochs = int(training_args.num_train_epochs) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() - epochs = tqdm(range(nb_epochs), desc=f"Epoch ... 
(1/{nb_epochs})", position=0) - for epoch in epochs: + num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs + # Create learning rate schedule + warmup_fn = optax.linear_schedule( + init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps + ) + decay_fn = optax.linear_schedule( + init_value=training_args.learning_rate, + end_value=0, + transition_steps=num_train_steps - training_args.warmup_steps, + ) + linear_decay_lr_schedule_fn = optax.join_schedules( + schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps] + ) + + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=1e-8, + weight_decay=training_args.weight_decay, + ) + + # Setup train state + state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) + + # Define gradient update step fn + def train_step(state, batch, dropout_rng): + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + + def loss_fn(params): + labels = batch.pop("labels") + + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + + # compute loss, ignore padded input tokens + label_mask = jnp.where(labels > 0, 1.0, 0.0) + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask + + # take average + loss = loss.sum() / label_mask.sum() + + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + + metrics = jax.lax.pmean( + {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" + ) + + return new_state, metrics, new_dropout_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Define eval fn + def eval_step(params, 
batch): + labels = batch.pop("labels") + + logits = model(**batch, params=params, train=False)[0] + + # compute loss, ignore padded input tokens + label_mask = jnp.where(labels > 0, 1.0, 0.0) + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask + + # compute accuracy + accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask + + # summarize metrics + metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()} + metrics = jax.lax.psum(metrics, axis_name="batch") + + return metrics + + p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + + train_metrics = [] + train_time = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: # ======================== Training ================================ + train_start = time.time() + # Create sampling rng - rng, training_rng, eval_rng = jax.random.split(rng, 3) + rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset - nb_training_samples = len(tokenized_datasets["train"]) - training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples)) - training_batch_idx = generate_batch_splits(training_samples_idx, batch_size) + num_train_samples = len(tokenized_datasets["train"]) + train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step - for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1): + for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward - 
model_inputs = common_utils.shard(model_inputs.data) - loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs) + model_inputs = shard(model_inputs.data) + state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) + train_metrics.append(train_metric) - epochs.write(f"Loss: {loss}") + train_time += time.time() - train_start + + epochs.write( + f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) # ======================== Evaluating ============================== - nb_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(nb_eval_samples) + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] @@ -692,26 +625,27 @@ if __name__ == "__main__": model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward - model_inputs = common_utils.shard(model_inputs.data) - metrics = p_eval_step(optimizer.target, model_inputs) + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) eval_metrics.append(metrics) - eval_metrics_np = get_metrics(eval_metrics) - eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np) - eval_normalizer = eval_metrics_np.pop("normalizer") - eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np) + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.sum, eval_metrics) + eval_normalizer = eval_metrics.pop("normalizer") + eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar epochs.desc = ( - f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})" + f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" ) # Save metrics - if has_tensorboard and jax.host_id() == 0: - for name, value in eval_summary.items(): - summary_writer.scalar(name, value, epoch) + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) + write_metric(train_metrics, eval_metrics, train_time, cur_step) - # save last checkpoint - if jax.host_id() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target)) - model.save_pretrained(training_args.output_dir, params=params) + # save last checkpoint + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) diff --git a/examples/flax/text-classification/requirements.txt b/examples/flax/text-classification/requirements.txt index f428e9cccb..112efe6897 100644 --- a/examples/flax/text-classification/requirements.txt +++ b/examples/flax/text-classification/requirements.txt @@ -1,5 +1,5 @@ datasets >= 1.1.3 jax>=0.2.8 jaxlib>=0.1.59 -git+https://github.com/google/flax.git +flax>=0.3.4 git+https://github.com/deepmind/optax.git