From 1c191efc3abc391072ff0094a8108459bc08e3fa Mon Sep 17 00:00:00 2001 From: Kamal Raj Date: Thu, 9 Sep 2021 10:12:57 +0530 Subject: [PATCH] flax ner example (#13365) * flax ner example * added task to README * updated readme * 1. ArgumentParser -> HfArgumentParser 2. step-wise logging,eval and save * added requirements.txt * added progress bar * updated README * added check_min_version * updated training data permuattion with JAX * added metric lib to requirements * updated readme table * fixed imports --- examples/flax/token-classification/README.md | 74 ++ .../token-classification/requirements.txt | 6 + .../flax/token-classification/run_flax_ner.py | 669 ++++++++++++++++++ 3 files changed, 749 insertions(+) create mode 100644 examples/flax/token-classification/README.md create mode 100644 examples/flax/token-classification/requirements.txt create mode 100644 examples/flax/token-classification/run_flax_ner.py diff --git a/examples/flax/token-classification/README.md b/examples/flax/token-classification/README.md new file mode 100644 index 0000000000..34f156e1c4 --- /dev/null +++ b/examples/flax/token-classification/README.md @@ -0,0 +1,74 @@ + + +# Token classification examples + +Fine-tuning the library models for token classification task such as Named Entity Recognition (NER), Parts-of-speech tagging (POS) or phrase extraction (CHUNKS). The main script run_flax_ner.py leverages the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets. + +It will either run on a datasets hosted on our hub or with your own text files for training and validation, you might just need to add some tweaks in the data preprocessing. + +The following example fine-tunes BERT on CoNLL-2003: + +To begin with it is recommended to create a model repository to save the trained model and logs. +Here we call the model `"bert-ner-conll2003-test"`, but you can change the model name as you like. + +You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that +you are logged in) or via the command line: + +``` +huggingface-cli repo create bert-ner-conll2003-test +``` + +Next we clone the model repository to add the tokenizer and model files. + +``` +git clone https://huggingface.co//bert-ner-conll2003-test +``` + +Great, we have set up our model repository. During training, we will automatically +push the training logs and model weights to the repo. + +Next, let's add a symbolic link to the `run_flax_ner.py`. + +```bash +export MODEL_DIR="./bert-ner-conll2003-test" +ln -s ~/transformers/examples/flax/token-classification/run_flax_ner.py run_flax_ner.py +``` + +```bash +python run_flax_ner.py \ + --model_name_or_path bert-base-cased \ + --dataset_name conll2003 \ + --max_seq_length 128 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --per_device_train_batch_size 4 \ + --output_dir ${MODEL_DIR} \ + --eval_steps 300 \ + --push_to_hub +``` + +Using the command above, the script will train for 3 epochs and run eval after each epoch. +Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`. +You can see the results by running `tensorboard` in that directory: + +```bash +$ tensorboard --logdir . +``` + +or directly on the hub under *Training metrics*. + +sample Metrics - [tfhub.dev](https://tensorboard.dev/experiment/u52qsBIpQSKEEXEJd2LVYA) \ No newline at end of file diff --git a/examples/flax/token-classification/requirements.txt b/examples/flax/token-classification/requirements.txt new file mode 100644 index 0000000000..e3fbc46aa1 --- /dev/null +++ b/examples/flax/token-classification/requirements.txt @@ -0,0 +1,6 @@ +datasets >= 1.8.0 +jax>=0.2.8 +jaxlib>=0.1.59 +flax>=0.3.4 +optax>=0.0.8 +seqeval \ No newline at end of file diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py new file mode 100644 index 0000000000..1596b80e1c --- /dev/null +++ b/examples/flax/token-classification/run_flax_ner.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)""" +import logging +import os +import random +import sys +import time +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, Callable, Dict, Optional, Tuple + +import datasets +import numpy as np +from datasets import ClassLabel, load_dataset, load_metric +from tqdm import tqdm + +import jax +import jax.numpy as jnp +import optax +import transformers +from flax import struct, traverse_util +from flax.jax_utils import replicate, unreplicate +from flax.metrics import tensorboard +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard +from transformers import ( + AutoConfig, + AutoTokenizer, + FlaxAutoModelForTokenClassification, + HfArgumentParser, + TrainingArguments, +) +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.11.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") + +Array = Any +Dataset = datasets.arrow_dataset.Dataset +PRNGKey = Any + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a csv or JSON file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, + ) + text_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} + ) + label_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. If set, sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + label_all_tokens: bool = field( + default=False, + metadata={ + "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " + "one (in which case the other tokens will have a padding index)." + }, + ) + return_entity_level_metrics: bool = field( + default=False, + metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + self.task_name = self.task_name.lower() + + +def create_train_state( + model: FlaxAutoModelForTokenClassification, + learning_rate_fn: Callable[[int], float], + num_labels: int, + training_args: TrainingArguments, +) -> train_state.TrainState: + """Create initial training state.""" + + class TrainState(train_state.TrainState): + """Train state with an Optax optimizer. + + The two functions below differ depending on whether the task is classification + or regression. + + Args: + logits_fn: Applied to last layer to obtain the logits. + loss_fn: Function to compute the loss. + """ + + logits_fn: Callable = struct.field(pytree_node=False) + loss_fn: Callable = struct.field(pytree_node=False) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxBERT-like models. + # For other models, one should correct the layer norm parameter naming + # accordingly. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + tx = optax.adamw( + learning_rate=learning_rate_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + def cross_entropy_loss(logits, labels): + xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels)) + return jnp.mean(xentropy) + + return TrainState.create( + apply_fn=model.__call__, + params=model.params, + tx=tx, + logits_fn=lambda logits: logits.argmax(-1), + loss_fn=cross_entropy_loss, + ) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): + """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices.""" + steps_per_epoch = len(dataset) // batch_size + perms = jax.random.permutation(rng, len(dataset)) + perms = perms[: steps_per_epoch * batch_size] # Skip incomplete batch. + perms = perms.reshape((steps_per_epoch, batch_size)) + + for perm in perms: + batch = dataset[perm] + batch = {k: np.array(v) for k, v in batch.items()} + batch = shard(batch) + + yield batch + + +def eval_data_collator(dataset: Dataset, batch_size: int): + """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" + for i in range(len(dataset) // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size] + batch = {k: np.array(v) for k, v in batch.items()} + batch = shard(batch) + + yield batch + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called + # 'tokens' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir + ) + else: + # Loading the dataset from local csv or json file. + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + if raw_datasets["train"] is not None: + column_names = raw_datasets["train"].column_names + features = raw_datasets["train"].features + else: + column_names = raw_datasets["validation"].column_names + features = raw_datasets["validation"].features + + if data_args.text_column_name is not None: + text_column_name = data_args.text_column_name + elif "tokens" in column_names: + text_column_name = "tokens" + else: + text_column_name = column_names[0] + + if data_args.label_column_name is not None: + label_column_name = data_args.label_column_name + elif f"{data_args.task_name}_tags" in column_names: + label_column_name = f"{data_args.task_name}_tags" + else: + label_column_name = column_names[1] + + # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the + # unique labels. + def get_label_list(labels): + unique_labels = set() + for label in labels: + unique_labels = unique_labels | set(label) + label_list = list(unique_labels) + label_list.sort() + return label_list + + if isinstance(features[label_column_name].feature, ClassLabel): + label_list = features[label_column_name].feature.names + # No need to convert the labels since they are already ints. + label_to_id = {i: i for i in range(len(label_list))} + else: + label_list = get_label_list(raw_datasets["train"][label_column_name]) + label_to_id = {l: i for i, l in enumerate(label_list)} + num_labels = len(label_list) + + # Load pretrained model and tokenizer + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + label2id=label_to_id, + id2label={i: l for l, i in label_to_id.items()}, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path + if config.model_type in {"gpt2", "roberta"}: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + add_prefix_space=True, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = FlaxAutoModelForTokenClassification.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + # Tokenize all texts and align the labels with them. + def tokenize_and_align_labels(examples): + tokenized_inputs = tokenizer( + examples[text_column_name], + max_length=data_args.max_seq_length, + padding="max_length", + truncation=True, + # We use this argument because the texts in our dataset are lists of words (with a label for each word). + is_split_into_words=True, + ) + + labels = [] + + for i, label in enumerate(examples[label_column_name]): + word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + # Special tokens have a word id that is None. We set the label to -100 so they are automatically + # ignored in the loss function. + if word_idx is None: + label_ids.append(-100) + # We set the label for the first token of each word. + elif word_idx != previous_word_idx: + label_ids.append(label_to_id[label[word_idx]]) + # For the other tokens in a word, we set the label to either the current label or -100, depending on + # the label_all_tokens flag. + else: + label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) + previous_word_idx = word_idx + + labels.append(label_ids) + tokenized_inputs["labels"] = labels + return tokenized_inputs + + processed_raw_datasets = raw_datasets.map( + tokenize_and_align_labels, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on dataset", + ) + + train_dataset = processed_raw_datasets["train"] + eval_dataset = processed_raw_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Define a summary writer + summary_writer = tensorboard.SummaryWriter(training_args.output_dir) + summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + + def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + def write_eval_metric(summary_writer, eval_metrics, step): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + num_epochs = int(training_args.num_train_epochs) + rng = jax.random.PRNGKey(training_args.seed) + dropout_rngs = jax.random.split(rng, jax.local_device_count()) + + train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() + eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() + + learning_rate_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args) + + # define step functions + def train_step( + state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey + ) -> Tuple[train_state.TrainState, float]: + """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + targets = batch.pop("labels") + + def loss_fn(params): + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss = state.loss_fn(logits, targets) + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") + return new_state, metrics, new_dropout_rng + + p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) + + def eval_step(state, batch): + logits = state.apply_fn(**batch, params=state.params, train=False)[0] + return state.logits_fn(logits) + + p_eval_step = jax.pmap(eval_step, axis_name="batch") + + metric = load_metric("seqeval") + + def get_labels(y_pred, y_true): + # Transform predictions and references tensos to numpy arrays + + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100] + for pred, gold_label in zip(y_pred, y_true) + ] + true_labels = [ + [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100] + for pred, gold_label in zip(y_pred, y_true) + ] + return true_predictions, true_labels + + def compute_metrics(): + results = metric.compute() + if data_args.return_entity_level_metrics: + # Unpack nested dictionaries + final_results = {} + for key, value in results.items(): + if isinstance(value, dict): + for n, v in value.items(): + final_results[f"{key}_{n}"] = v + else: + final_results[key] = value + return final_results + else: + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + logger.info(f"===== Starting training ({num_epochs} epochs) =====") + train_time = 0 + + # make sure weights are replicated on each device + state = replicate(state) + + train_time = 0 + step_per_epoch = len(train_dataset) // train_batch_size + total_steps = step_per_epoch * num_epochs + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + + train_start = time.time() + train_metrics = [] + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # train + for step, batch in enumerate( + tqdm( + train_data_collator(input_rng, train_dataset, train_batch_size), + total=step_per_epoch, + desc="Training...", + position=1, + ) + ): + state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) + train_metrics.append(train_metric) + + cur_step = epoch * step_per_epoch + step + + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + if jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) + + train_metrics = [] + + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + + eval_metrics = {} + # evaluate + for batch in tqdm( + eval_data_collator(eval_dataset, eval_batch_size), + total=len(eval_dataset) // eval_batch_size, + desc="Evaluating ...", + position=2, + ): + labels = batch.pop("labels") + predictions = p_eval_step(state, batch) + predictions = np.array([pred for pred in chain(*predictions)]) + labels = np.array([label for label in chain(*labels)]) + labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 + preds, refs = get_labels(predictions, labels) + metric.add_batch( + predictions=preds, + references=refs, + ) + + # evaluate also on leftover examples (not divisible by batch_size) + num_leftover_samples = len(eval_dataset) % eval_batch_size + + # make sure leftover batch is evaluated on one device + if num_leftover_samples > 0 and jax.process_index() == 0: + # take leftover samples + batch = eval_dataset[-num_leftover_samples:] + batch = {k: np.array(v) for k, v in batch.items()} + + labels = batch.pop("labels") + predictions = eval_step(unreplicate(state), batch) + labels = np.array(labels) + labels[np.array(batch["attention_mask"]) == 0] = -100 + preds, refs = get_labels(predictions, labels) + metric.add_batch( + predictions=preds, + references=refs, + ) + + eval_metrics = compute_metrics() + + if data_args.return_entity_level_metrics: + logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}") + else: + logger.info( + f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})" + ) + + if jax.process_index() == 0: + write_eval_metric(summary_writer, eval_metrics, cur_step) + + if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(unreplicate(state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + ) + epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" + + +if __name__ == "__main__": + main()