From f1c81d6b921fae7adf0f80be7e2567e92221ab1f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 5 Jul 2021 18:23:03 +0530 Subject: [PATCH] [Flax] ViT training example (#12300) * begin script * clean example, add readme * update readme * remove decay mask * remove masking * update readme & make flake happy --- examples/flax/vision/README.md | 101 ++++ examples/flax/vision/requirements.txt | 8 + .../flax/vision/run_image_classification.py | 467 ++++++++++++++++++ 3 files changed, 576 insertions(+) create mode 100644 examples/flax/vision/README.md create mode 100644 examples/flax/vision/requirements.txt create mode 100644 examples/flax/vision/run_image_classification.py diff --git a/examples/flax/vision/README.md b/examples/flax/vision/README.md new file mode 100644 index 0000000000..19a213b838 --- /dev/null +++ b/examples/flax/vision/README.md @@ -0,0 +1,101 @@ + + +# Image Classification training examples + +The following example showcases how to train/fine-tune `ViT` for image-classification using the JAX/Flax backend. + +JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. +Models written in JAX/Flax are **immutable** and updated in a purely functional +way which enables simple and efficient model parallelism. + + +In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset. + +Let's start by creating a model repository to save the trained model and logs. +Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like. + +You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that +you are logged in) or via the command line: + +``` +huggingface-cli repo create vit-base-patch16-imagenette +``` +Next we clone the model repository to add the tokenizer and model files. +``` +git clone https://huggingface.co//vit-base-patch16-imagenette +``` +To ensure that all tensorboard traces will be uploaded correctly, we need to +track them. You can run the following command inside your model repo to do so. + +``` +cd vit-base-patch16-imagenette +git lfs track "*tfevents*" +``` + +Great, we have set up our model repository. During training, we will automatically +push the training logs and model weights to the repo. + +Next, let's add a symbolic link to the `run_image_classification_flax.py`. + +```bash +export MODEL_DIR="./vit-base-patch16-imagenette +ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py +``` + +## Prepare the dataset + +We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). + + +### Download and extract the data. + +```bash +wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz +tar -xvzf imagenette2.tgz +``` + +This will create a `imagenette2` dir with two subdirectories `train` and `val` each with multiple subdirectories per class. The training script expects the following directory structure + +```bash +root/dog/xxx.png +root/dog/xxy.png +root/dog/[...]/xxz.png + +root/cat/123.png +root/cat/nsdf3.png +root/cat/[...]/asd932_.png +``` + +## Train the model + +Next we can run the example script to fine-tune the model: + +```bash +python run_image_classification.py \ + --output_dir ${MODEL_DIR} \ + --model_name_or_path google/vit-base-patch16-224-in21k \ + --train_dir="imagenette2/train" \ + --validation_dir="imagenette2/val" \ + --num_train_epochs 5 \ + --learning_rate 1e-3 \ + --per_device_train_batch_size 128 --per_device_eval_batch_size 128 \ + --overwrite_output_dir \ + --preprocessing_num_workers 32 \ + --push_to_hub +``` + +This should finish in ~7mins with 99% validation accuracy. \ No newline at end of file diff --git a/examples/flax/vision/requirements.txt b/examples/flax/vision/requirements.txt new file mode 100644 index 0000000000..67881c95ce --- /dev/null +++ b/examples/flax/vision/requirements.txt @@ -0,0 +1,8 @@ +jax>=0.2.8 +jaxlib>=0.1.59 +flax>=0.3.4 +optax>=0.0.8 +-f https://download.pytorch.org/whl/torch_stable.html +torch==1.9.0+cpu +-f https://download.pytorch.org/whl/torch_stable.html +torchvision==0.10.0+cpu \ No newline at end of file diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py new file mode 100644 index 0000000000..e12a20aa27 --- /dev/null +++ b/examples/flax/vision/run_image_classification.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pre-training/Fine-tuning ViT for image classification . + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=vit +""" + +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Optional + +# for dataset and preprocessing +import torch +import torchvision +import torchvision.transforms as transforms +from tqdm import tqdm + +import jax +import jax.numpy as jnp +import optax +import transformers +from flax import jax_utils +from flax.jax_utils import unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + AutoConfig, + FlaxAutoModelForImageClassification, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + set_seed, +) + + +logger = logging.getLogger(__name__) + + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + train_dir: str = field( + metadata={"help": "Path to the root training directory which contains one subdirectory per class."} + ) + validation_dir: str = field( + metadata={"help": "Path to the root validation directory which contains one subdirectory per class."}, + ) + image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."}) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + + +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # set seed for random transforms and torch dataloaders + set_seed(training_args.seed) + + # Initialize datasets and pre-processing transforms + # We use torchvision here for faster pre-processing + # Note that here we are using some default pre-processing, for maximum accuray + # one should tune this part and carefully select what transformations to use. + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + train_dataset = torchvision.datasets.ImageFolder( + data_args.train_dir, + transforms.Compose( + [ + transforms.RandomResizedCrop(data_args.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + + eval_dataset = torchvision.datasets.ImageFolder( + data_args.validation_dir, + transforms.Compose( + [ + transforms.Resize(data_args.image_size), + transforms.CenterCrop(data_args.image_size), + transforms.ToTensor(), + normalize, + ] + ), + ) + + # Load pretrained model and tokenizer + if model_args.config_name: + config = AutoConfig.from_pretrained( + model_args.config_name, + num_labels=len(train_dataset.classes), + image_size=data_args.image_size, + cache_dir=model_args.cache_dir, + ) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=len(train_dataset.classes), + image_size=data_args.image_size, + cache_dir=model_args.cache_dir, + ) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if model_args.model_name_or_path: + model = FlaxAutoModelForImageClassification.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + model = FlaxAutoModelForImageClassification.from_config( + config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + + # Store some constant + num_epochs = int(training_args.num_train_epochs) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + steps_per_epoch = len(train_dataset) // train_batch_size + total_train_steps = steps_per_epoch * num_epochs + + def collate_fn(examples): + pixel_values = torch.stack([example[0] for example in examples]) + labels = torch.tensor([example[1] for example in examples]) + + batch = {"pixel_values": pixel_values, "labels": labels} + batch = {k: v.numpy() for k, v in batch.items()} + + return batch + + # Create data loaders + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=data_args.preprocessing_num_workers, + persistent_workers=True, + drop_last=True, + collate_fn=collate_fn, + ) + + eval_loader = torch.utils.data.DataLoader( + eval_dataset, + batch_size=eval_batch_size, + shuffle=False, + num_workers=data_args.preprocessing_num_workers, + persistent_workers=True, + drop_last=True, + collate_fn=collate_fn, + ) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + ) + + # Setup train state + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + + def loss_fn(logits, labels): + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) + return loss.mean() + + # Define gradient update step fn + def train_step(state, batch): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss = loss_fn(logits, labels) + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + loss = loss_fn(logits, labels) + + # summarize metrics + accuracy = (jnp.argmax(logits, axis=-1) == labels).mean() + metrics = {"loss": loss, "accuracy": accuracy} + metrics = jax.lax.pmean(metrics, axis_name="batch") + return metrics + + # Create parallel version of the train and eval step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + p_eval_step = jax.pmap(eval_step, "batch") + + # Replicate the train state on each device + state = state.replicate() + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") + logger.info(f" Total optimization steps = {total_train_steps}") + + train_time = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + train_metrics = [] + + steps_per_epoch = len(train_dataset) // train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) + # train + for batch in train_loader: + batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + train_time += time.time() - train_start + + train_metric = unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write( + f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) + + # ======================== Evaluating ============================== + eval_metrics = [] + eval_steps = len(eval_dataset) // eval_batch_size + eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False) + for batch in eval_loader: + # Model forward + batch = shard(batch) + metrics = p_eval_step(state.params, batch) + eval_metrics.append(metrics) + + eval_step_progress_bar.update(1) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + + # Print metrics and update progress bar + eval_step_progress_bar.close() + desc = ( + f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | " + f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})" + ) + epochs.write(desc) + epochs.desc = desc + + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(train_dataset) // train_batch_size) + write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of epoch {epoch+1}", + ) + + +if __name__ == "__main__": + main()