From d1f5ca1afd03e38f45062b2a06f1846a7c290da4 Mon Sep 17 00:00:00 2001 From: Kamal Raj Date: Wed, 19 Jan 2022 16:34:51 +0530 Subject: [PATCH] [FLAX] glue training example refactor (#13815) * refactor run_flax_glue.py * updated readme * rm unused import and args typo fix * refactor * make consistent arg name across task * has_tensorboard check * argparse -> argument dataclasses * refactor according to review * fix --- examples/flax/test_examples.py | 6 +- examples/flax/text-classification/README.md | 5 +- .../flax/text-classification/run_flax_glue.py | 391 +++++++++++------- 3 files changed, 238 insertions(+), 164 deletions(-) diff --git a/examples/flax/test_examples.py b/examples/flax/test_examples.py index d57c0d36c6..98c29a821c 100644 --- a/examples/flax/test_examples.py +++ b/examples/flax/test_examples.py @@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus): --per_device_train_batch_size=2 --per_device_eval_batch_size=1 --learning_rate=1e-4 - --max_train_steps=10 - --num_warmup_steps=2 + --eval_steps=2 + --warmup_steps=2 --seed=42 - --max_length=128 + --max_seq_length=128 """.split() with patch.object(sys, "argv", testargs): diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index bf4c4c79cc..f7e27b4245 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -33,15 +33,16 @@ export TASK_NAME=mrpc python run_flax_glue.py \ --model_name_or_path bert-base-cased \ --task_name ${TASK_NAME} \ - --max_length 128 \ + --max_seq_length 128 \ --learning_rate 2e-5 \ --num_train_epochs 3 \ --per_device_train_batch_size 4 \ + --eval_steps 100 \ --output_dir ./$TASK_NAME/ \ --push_to_hub ``` -where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli. +where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli. Using the command above, the script will train for 3 epochs and run eval after each epoch. Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`. diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index f27b7cd05c..96b5bbdd3d 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -14,18 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE.""" -import argparse import json import logging import os import random +import sys import time +from dataclasses import dataclass, field from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import datasets +import numpy as np from datasets import load_dataset, load_metric +from tqdm import tqdm import jax import jax.numpy as jnp @@ -40,13 +43,18 @@ from transformers import ( AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, + HfArgumentParser, PretrainedConfig, + TrainingArguments, is_tensorboard_available, ) from transformers.file_utils import get_full_repo_name +from transformers.utils import check_min_version logger = logging.getLogger(__name__) +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.16.0.dev0") Array = Any Dataset = datasets.arrow_dataset.Dataset @@ -66,101 +74,118 @@ task_to_keys = { } -def parse_args(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") - parser.add_argument( - "--task_name", - type=str, +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + use_slow_tokenizer: Optional[bool] = field( + default=False, + metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."}, + ) + cache_dir: Optional[str] = field( default=None, - help="The name of the glue task to train on.", - choices=list(task_to_keys.keys()), + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, ) - parser.add_argument( - "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, ) - parser.add_argument( - "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, ) - parser.add_argument( - "--max_length", - type=int, - default=128, - help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," - " sequences shorter will be padded." - ), + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + task_name: Optional[str] = field( + default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"} ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=True, + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a csv or JSON file)."} ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - parser.add_argument( - "--max_train_steps", - type=int, + validation_file: Optional[str] = field( default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, ) - parser.add_argument( - "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, ) - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.") - parser.add_argument( - "--push_to_hub", - action="store_true", - help="If passed, model checkpoints and tensorboard logs will be pushed to the hub", + text_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} ) - parser.add_argument( - "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + label_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. If set, sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, ) - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - args = parser.parse_args() - # Sanity checks - if args.task_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a task name or a training/validation file.") - else: - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - return args + def __post_init__(self): + if self.task_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + self.task_name = self.task_name.lower() if type(self.task_name) == str else self.task_name def create_train_state( @@ -249,7 +274,7 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): for perm in perms: batch = dataset[perm] - batch = {k: jnp.array(v) for k, v in batch.items()} + batch = {k: np.array(v) for k, v in batch.items()} batch = shard(batch) yield batch @@ -259,14 +284,20 @@ def glue_eval_data_collator(dataset: Dataset, batch_size: int): """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" for i in range(len(dataset) // batch_size): batch = dataset[i * batch_size : (i + 1) * batch_size] - batch = {k: jnp.array(v) for k, v in batch.items()} + batch = {k: np.array(v) for k, v in batch.items()} batch = shard(batch) yield batch def main(): - args = parse_args() + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -284,12 +315,14 @@ def main(): transformers.utils.logging.set_verbosity_error() # Handle the repository creation - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token) + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). @@ -303,24 +336,24 @@ def main(): # In distributed training, the load_dataset function guarantee that only one local process can concurrently # download the dataset. - if args.task_name is not None: + if data_args.task_name is not None: # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset("glue", args.task_name) + raw_datasets = load_dataset("glue", data_args.task_name) else: # Loading the dataset from local csv or json file. data_files = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = (args.train_file if args.train_file is not None else args.valid_file).split(".")[-1] + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1] raw_datasets = load_dataset(extension, data_files=data_files) # See more about loading any type of standard or custom dataset at # https://huggingface.co/docs/datasets/loading_datasets.html. # Labels - if args.task_name is not None: - is_regression = args.task_name == "stsb" + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" if not is_regression: label_list = raw_datasets["train"].features["label"].names num_labels = len(label_list) @@ -339,13 +372,17 @@ def main(): num_labels = len(label_list) # Load pretrained model and tokenizer - config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config) + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer + ) + model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config) # Preprocessing the datasets - if args.task_name is not None: - sentence1_key, sentence2_key = task_to_keys[args.task_name] + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] else: # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] @@ -361,7 +398,7 @@ def main(): label_to_id = None if ( model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id - and args.task_name is not None + and data_args.task_name is not None and not is_regression ): # Some have all caps in their config, some don't. @@ -378,7 +415,7 @@ def main(): f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." "\nIgnoring the model labels as a result.", ) - elif args.task_name is None: + elif data_args.task_name is None: label_to_id = {v: i for i, v in enumerate(label_list)} def preprocess_function(examples): @@ -386,7 +423,7 @@ def main(): texts = ( (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) ) - result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True) + result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True) if "label" in examples: if label_to_id is not None: @@ -402,7 +439,7 @@ def main(): ) train_dataset = processed_datasets["train"] - eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] + eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): @@ -414,8 +451,8 @@ def main(): try: from flax.metrics.tensorboard import SummaryWriter - summary_writer = SummaryWriter(args.output_dir) - summary_writer.hparams(vars(args)) + summary_writer = SummaryWriter(training_args.output_dir) + summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) except ImportError as ie: has_tensorboard = False logger.warning( @@ -427,7 +464,7 @@ def main(): "Please run pip install tensorboard to enable." ) - def write_metric(train_metrics, eval_metrics, train_time, step): + def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) @@ -436,22 +473,27 @@ def main(): for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) + def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) - num_epochs = int(args.num_train_epochs) - rng = jax.random.PRNGKey(args.seed) + num_epochs = int(training_args.num_train_epochs) + rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - train_batch_size = args.per_device_train_batch_size * jax.local_device_count() - eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count() + train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() + eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() learning_rate_fn = create_learning_rate_fn( - len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, ) state = create_train_state( - model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay + model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay ) # define step functions @@ -482,8 +524,8 @@ def main(): p_eval_step = jax.pmap(eval_step, axis_name="batch") - if args.task_name is not None: - metric = load_metric("glue", args.task_name) + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) else: metric = load_metric("accuracy") @@ -493,63 +535,94 @@ def main(): # make sure weights are replicated on each device state = replicate(state) - for epoch in range(1, num_epochs + 1): - logger.info(f"Epoch {epoch}") - logger.info(" Training...") + steps_per_epoch = len(train_dataset) // train_batch_size + total_steps = steps_per_epoch * num_epochs + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0) + for epoch in epochs: train_start = time.time() train_metrics = [] + + # Create sampling rng rng, input_rng = jax.random.split(rng) # train - for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size): - state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs) - train_metrics.append(metrics) - train_time += time.time() - train_start - logger.info(f" Done! Training metrics: {unreplicate(metrics)}") + train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size) + for step, batch in enumerate( + tqdm( + train_loader, + total=steps_per_epoch, + desc="Training...", + position=1, + ), + ): + state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) + train_metrics.append(train_metric) - logger.info(" Evaluating...") + cur_step = (epoch * steps_per_epoch) + (step + 1) - # evaluate - for batch in glue_eval_data_collator(eval_dataset, eval_batch_size): - labels = batch.pop("labels") - predictions = p_eval_step(state, batch) - metric.add_batch(predictions=chain(*predictions), references=chain(*labels)) + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size + epochs.write( + f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: jnp.array(v) for k, v in batch.items()} + train_metrics = [] - labels = batch.pop("labels") - predictions = eval_step(unreplicate(state), batch) - metric.add_batch(predictions=predictions, references=labels) + if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0: - eval_metric = metric.compute() - logger.info(f" Done! Eval metrics: {eval_metric}") + eval_metrics = {} + # evaluate + eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size) + for batch in tqdm( + eval_loader, + total=len(eval_dataset) // eval_batch_size, + desc="Evaluating ...", + position=2, + ): + labels = batch.pop("labels") + predictions = p_eval_step(state, batch) + metric.add_batch(predictions=chain(*predictions), references=chain(*labels)) - cur_step = epoch * (len(train_dataset) // train_batch_size) + # evaluate also on leftover examples (not divisible by batch_size) + num_leftover_samples = len(eval_dataset) % eval_batch_size - # Save metrics - if has_tensorboard and jax.process_index() == 0: - write_metric(train_metrics, eval_metric, train_time, cur_step) + # make sure leftover batch is evaluated on one device + if num_leftover_samples > 0 and jax.process_index() == 0: + # take leftover samples + batch = eval_dataset[-num_leftover_samples:] + batch = {k: np.array(v) for k, v in batch.items()} - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained(args.output_dir, params=params) - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) + labels = batch.pop("labels") + predictions = eval_step(unreplicate(state), batch) + metric.add_batch(predictions=predictions, references=labels) + + eval_metric = metric.compute() + + logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})") + + if has_tensorboard and jax.process_index() == 0: + write_eval_metric(summary_writer, eval_metrics, cur_step) + + if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(unreplicate(state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" # save the eval metrics in json if jax.process_index() == 0: eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()} - path = os.path.join(args.output_dir, "eval_results.json") + path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metric, f, indent=4, sort_keys=True)