From 75627148ee131ad274360633686660d59335cc02 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Wed, 9 Dec 2020 17:13:56 +0100 Subject: [PATCH] Flax Masked Language Modeling training example (#8728) * Remove "Model" suffix from Flax models to look more :hugs: Signed-off-by: Morgan Funtowicz * Initial working (forward + backward) for Flax MLM training example. Signed-off-by: Morgan Funtowicz * Simply code Signed-off-by: Morgan Funtowicz * Addressing comments, using module and moving to LM task. Signed-off-by: Morgan Funtowicz * Restore parameter name "module" wrongly renamed model. Signed-off-by: Morgan Funtowicz * Restore correct output ordering... Signed-off-by: Morgan Funtowicz * Actually commit the example :sweat_smile: Signed-off-by: Morgan Funtowicz * Add FlaxBertModelForMaskedLM after rebasing. Signed-off-by: Morgan Funtowicz * Make it possible to initialize the training from scratch Signed-off-by: Morgan Funtowicz * Reuse flax linen example of cross entropy loss Signed-off-by: Morgan Funtowicz * Added specific data collator for flax Signed-off-by: Morgan Funtowicz * Remove todo for data collator Signed-off-by: Morgan Funtowicz * Added evaluation step Signed-off-by: Morgan Funtowicz * Added ability to provide dtype to support bfloat16 on TPU Signed-off-by: Morgan Funtowicz * Enable flax tensorboard output Signed-off-by: Morgan Funtowicz * Enable jax.pmap support. Signed-off-by: Morgan Funtowicz * Ensure batches are correctly sized to be dispatched with jax.pmap Signed-off-by: Morgan Funtowicz * Enable bfloat16 with --fp16 cmdline args Signed-off-by: Morgan Funtowicz * Correctly export metrics to tensorboard Signed-off-by: Morgan Funtowicz * Added dropout and ability to use it. Signed-off-by: Morgan Funtowicz * Effectively enable & disable during training and evaluation steps. Signed-off-by: Morgan Funtowicz * Oops. Signed-off-by: Morgan Funtowicz * Enable specifying kernel initializer scale Signed-off-by: Morgan Funtowicz * Style. Signed-off-by: Morgan Funtowicz * Added warmup step to the learning rate scheduler. Signed-off-by: Morgan Funtowicz * Fix typo. Signed-off-by: Morgan Funtowicz * Print training loss Signed-off-by: Morgan Funtowicz * Make style Signed-off-by: Morgan Funtowicz * fix linter issue (flake8) Signed-off-by: Morgan Funtowicz * Fix model matching Signed-off-by: Morgan Funtowicz * Fix dummies Signed-off-by: Morgan Funtowicz * Fix non default dtype on Flax models Signed-off-by: Morgan Funtowicz * Use the same create_position_ids_from_input_ids for FlaxRoberta Signed-off-by: Morgan Funtowicz * Make Roberta attention as Bert Signed-off-by: Morgan Funtowicz * fix copy Signed-off-by: Morgan Funtowicz * Wording. Co-authored-by: Marc van Zee Co-authored-by: Marc van Zee --- examples/language-modeling/run_mlm_flax.py | 636 ++++++++++++++++++ src/transformers/__init__.py | 2 +- src/transformers/modeling_flax_utils.py | 18 +- src/transformers/models/bert/__init__.py | 2 +- .../models/bert/modeling_flax_bert.py | 361 ++++++++-- .../models/roberta/modeling_flax_roberta.py | 289 ++++++-- src/transformers/utils/dummy_flax_objects.py | 9 + tests/test_modeling_flax_roberta.py | 2 +- 8 files changed, 1187 insertions(+), 132 deletions(-) create mode 100644 examples/language-modeling/run_mlm_flax.py diff --git a/examples/language-modeling/run_mlm_flax.py b/examples/language-modeling/run_mlm_flax.py new file mode 100644 index 0000000000..0c2f0622a3 --- /dev/null +++ b/examples/language-modeling/run_mlm_flax.py @@ -0,0 +1,636 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a +text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=masked-lm +""" +import logging +import os +import sys +from dataclasses import dataclass, field + +# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +from datasets import load_dataset +from tqdm import tqdm + +import jax +import jax.numpy as jnp +from flax import jax_utils +from flax.optim import Adam +from flax.training import common_utils +from flax.training.common_utils import get_metrics +from jax.nn import log_softmax +from transformers import ( + CONFIG_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + AutoConfig, + AutoTokenizer, + FlaxBertForMaskedLM, + HfArgumentParser, + PreTrainedTokenizerBase, + TensorType, + TrainingArguments, + is_tensorboard_available, + set_seed, +) + + +# Cache the result +has_tensorboard = is_tensorboard_available() +if has_tensorboard: + try: + from flax.metrics.tensorboard import SummaryWriter + except ImportError as ie: + has_tensorboard = False + print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}") + +else: + print( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + train_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input train ref data file for whole word masking in Chinese."}, + ) + validation_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated. Default to the max input length of the model." + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +# Adapted from transformers/data/data_collator.py +# Letting here for now, let's discuss where it should live +@dataclass +class FlaxDataCollatorForLanguageModeling: + """ + Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they + are not all of the same length. + + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + mlm (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the + inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for + non-masked tokens and the value to predict for the masked token. + mlm_probability (:obj:`float`, `optional`, defaults to 0.15): + The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`. + + .. note:: + + For best performance, this data collator should be used with a dataset having items that are dictionaries or + BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a + :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the + argument :obj:`return_special_tokens_mask=True`. + """ + + tokenizer: PreTrainedTokenizerBase + mlm: bool = True + mlm_probability: float = 0.15 + + def __post_init__(self): + if self.mlm and self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. " + "You should pass `mlm=False` to train on causal language modeling instead." + ) + + def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]: + # Handle dict or lists with proper padding and conversion to tensor. + batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY) + + # If special token mask has been preprocessed, pop it from the dict. + special_tokens_mask = batch.pop("special_tokens_mask", None) + if self.mlm: + batch["input_ids"], batch["labels"] = self.mask_tokens( + batch["input_ids"], special_tokens_mask=special_tokens_mask + ) + else: + labels = batch["input_ids"].copy() + if self.tokenizer.pad_token_id is not None: + labels[labels == self.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + return batch + + def mask_tokens( + self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + """ + labels = inputs.copy() + # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) + probability_matrix = np.full(labels.shape, self.mlm_probability) + special_tokens_mask = special_tokens_mask.astype("bool") + + probability_matrix[special_tokens_mask] = 0.0 + masked_indices = np.random.binomial(1, probability_matrix).astype("bool") + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool") + indices_random &= masked_indices & ~indices_replaced + + random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4") + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + +def create_learning_rate_scheduler( + factors="constant * linear_warmup * rsqrt_decay", + base_learning_rate=0.5, + warmup_steps=1000, + decay_factor=0.5, + steps_per_decay=20000, + steps_per_cycle=100000, +): + """Creates learning rate schedule. + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) + * decay_every: Every k steps decay the learning rate by decay_factor. + * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. + Args: + factors: string, factors separated by "*" that defines the schedule. + base_learning_rate: float, the starting constant for the lr schedule. + warmup_steps: int, how many steps to warm up for in the warmup schedule. + decay_factor: float, the amount to decay the learning rate by. + steps_per_decay: int, how often to decay the learning rate. + steps_per_cycle: int, steps per cycle when using cosine decay. + Returns: + a function learning_rate(step): float -> {"learning_rate": float}, the + step-dependent lr. + """ + factors = [n.strip() for n in factors.split("*")] + + def step_fn(step): + """Step to learning rate function.""" + ret = 1.0 + for name in factors: + if name == "constant": + ret *= base_learning_rate + elif name == "linear_warmup": + ret *= jnp.minimum(1.0, step / warmup_steps) + elif name == "rsqrt_decay": + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "rsqrt_normalized_decay": + ret *= jnp.sqrt(warmup_steps) + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "decay_every": + ret *= decay_factor ** (step // steps_per_decay) + elif name == "cosine_decay": + progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle)) + ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) + else: + raise ValueError("Unknown factor %s." % name) + return jnp.asarray(ret, dtype=jnp.float32) + + return step_fn + + +def compute_metrics(logits, labels, weights, label_smoothing=0.0): + """Compute summary metrics.""" + loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing) + acc, _ = accuracy(logits, labels, weights) + metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer} + metrics = jax.lax.psum(metrics, axis_name="batch") + return metrics + + +def accuracy(logits, targets, weights=None): + """Compute weighted accuracy for log probs and targets. + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: None or array of shape [batch, length] + Returns: + Tuple of scalar loss and batch normalizing factor. + """ + if logits.ndim != targets.ndim + 1: + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) + ) + + loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) + loss *= weights + + return loss.sum(), weights.sum() + + +def cross_entropy(logits, targets, weights=None, label_smoothing=0.0): + """Compute cross entropy and entropy for log probs and targets. + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: None or array of shape [batch, length] + label_smoothing: label smoothing constant, used to determine the on and off values. + Returns: + Tuple of scalar loss and batch normalizing factor. + """ + if logits.ndim != targets.ndim + 1: + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) + ) + + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1) + loss = loss - normalizing_constant + + if weights is not None: + loss = loss * weights + normalizing_factor = weights.sum() + else: + normalizing_factor = np.prod(targets.shape) + + return loss.sum(), normalizing_factor + + +def training_step(optimizer, batch, dropout_rng): + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + + def loss_fn(params): + targets = batch.pop("labels") + + # Hide away tokens which doesn't participate in the optimization + token_mask = jnp.where(targets > 0, 1.0, 0.0) + + pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True) + loss, weight_sum = cross_entropy(logits, targets, token_mask) + return loss / weight_sum + + step = optimizer.state.step + lr = lr_scheduler_fn(step) + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(optimizer.target) + grad = jax.lax.pmean(grad, "batch") + optimizer = optimizer.apply_gradient(grad, learning_rate=lr) + + return loss, optimizer, new_dropout_rng + + +def eval_step(params, batch): + """ + Calculate evaluation metrics on a batch. + """ + targets = batch.pop("labels") + + # Hide away tokens which doesn't participate in the optimization + token_mask = jnp.where(targets > 0, 1.0, 0.0) + _, logits = model(**batch, params=params, train=False) + + return compute_metrics(logits, targets, token_mask) + + +def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: + nb_samples = len(samples_idx) + samples_to_remove = nb_samples % batch_size + + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = nb_samples // batch_size + batch_idx = jnp.split(samples_idx, sections_split) + return batch_idx + + +if __name__ == "__main__": + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + level="NOTSET", + datefmt="[%X]", + ) + + # Log on each process the small summary: + logger = logging.getLogger(__name__) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + datasets = load_dataset(extension, data_files=data_files) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + # Preprocessing the datasets. + # First we tokenize all the texts. + if training_args.do_train: + column_names = datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + padding = "max_length" if data_args.pad_to_max_length else False + + def tokenize_function(examples): + # Remove empty lines + examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] + return tokenizer( + examples["text"], + return_special_tokens_mask=True, + padding=padding, + truncation=True, + max_length=data_args.max_seq_length, + ) + + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + ) + + # Enable tensorboard only on the master node + if has_tensorboard and jax.host_id() == 0: + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()) + + # Data collator + # This one will take care of randomly masking the tokens. + data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + dropout_rngs = jax.random.split(rng, jax.local_device_count()) + + model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1) + model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length)) + + # Setup optimizer + optimizer = Adam( + learning_rate=training_args.learning_rate, + weight_decay=training_args.weight_decay, + beta1=training_args.adam_beta1, + beta2=training_args.adam_beta2, + ).create(model.params) + + # Create learning rate scheduler + lr_scheduler_fn = create_learning_rate_scheduler( + base_learning_rate=training_args.learning_rate, warmup_steps=training_args.warmup_steps + ) + + # Create parallel version of the training and evaluation steps + p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,)) + p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) + + # Replicate the optimizer on each device + optimizer = jax_utils.replicate(optimizer) + + # Store some constant + nb_epochs = int(training_args.num_train_epochs) + batch_size = int(training_args.train_batch_size) + eval_batch_size = int(training_args.eval_batch_size) + + epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0) + for epoch in epochs: + + # ======================== Training ================================ + # Create sampling rng + rng, training_rng, eval_rng = jax.random.split(rng, 3) + + # Generate an epoch by shuffling sampling indices from the train dataset + nb_training_samples = len(tokenized_datasets["train"]) + training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples)) + training_batch_idx = generate_batch_splits(training_samples_idx, batch_size) + + # Gather the indexes for creating the batch and do a training step + for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1): + samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples, pad_to_multiple_of=16) + + # Model forward + model_inputs = common_utils.shard(model_inputs.data) + loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs) + + epochs.write(f"Loss: {loss}") + + # ======================== Evaluating ============================== + nb_eval_samples = len(tokenized_datasets["test"]) + eval_samples_idx = jnp.arange(nb_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples, pad_to_multiple_of=16) + + # Model forward + model_inputs = common_utils.shard(model_inputs.data) + metrics = p_eval_step(optimizer.target, model_inputs) + eval_metrics.append(metrics) + + eval_metrics_np = get_metrics(eval_metrics) + eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np) + eval_normalizer = eval_metrics_np.pop("normalizer") + eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np) + + # Update progress bar + epochs.desc = ( + f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})" + ) + + # Save metrics + if has_tensorboard and jax.host_id() == 0: + for name, value in eval_summary.items(): + summary_writer.scalar(name, value, epoch) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8b8deb2b4d..f0e22c58ec 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -936,7 +936,7 @@ else: if is_flax_available(): from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel - from .models.bert import FlaxBertModel + from .models.bert import FlaxBertForMaskedLM, FlaxBertModel from .models.roberta import FlaxRobertaModel else: # Import the same objects as dummies to get them in the namespace. diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 163bb4f2ef..69bf948b08 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -65,13 +65,12 @@ class FlaxPreTrainedModel(ABC): base_model_prefix = "" model_class = None - def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0): + def __init__( + self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32 + ): if config is None: raise ValueError("config cannot be None") - if module is None: - raise ValueError("module cannot be None") - if params is None: raise ValueError("state cannot be None") @@ -82,19 +81,23 @@ class FlaxPreTrainedModel(ABC): # Those are public as their type is generic to every derived classes. self.key = PRNGKey(seed) self.params = params - self.model = module + self.dtype = dtype @property def config(self) -> PretrainedConfig: return self._config + @property + def module(self) -> nn.Module: + return self._module + @staticmethod @abstractmethod def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict: raise NotImplementedError() @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs): r""" Instantiate a pretrained Flax model from a pre-trained model configuration. """ @@ -127,6 +130,9 @@ class FlaxPreTrainedModel(ABC): else: model_kwargs = kwargs + # Add the dtype to model_kwargs + model_kwargs["dtype"] = dtype + # Load model if pretrained_model_name_or_path is not None: if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py index 8cdd93bbd4..a52d2574d5 100644 --- a/src/transformers/models/bert/__init__.py +++ b/src/transformers/models/bert/__init__.py @@ -59,4 +59,4 @@ if is_tf_available(): ) if is_flax_available(): - from .modeling_flax_bert import FlaxBertModel + from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index d4e9334378..f0f39d9758 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict +from typing import Callable, Dict, Tuple import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_flax_utils import FlaxPreTrainedModel, gelu @@ -101,8 +102,8 @@ class FlaxBertLayerNorm(nn.Module): scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear # (also e.g. nn.relu), this can be disabled since the scaling will be # done by the next layer. - bias_init: jnp.ndarray = nn.initializers.zeros - scale_init: jnp.ndarray = nn.initializers.ones + scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros @nn.compact def __call__(self, x): @@ -122,11 +123,13 @@ class FlaxBertLayerNorm(nn.Module): mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) var = mean2 - jax.lax.square(mean) mul = jax.lax.rsqrt(var + self.epsilon) + if self.scale: - mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype) + mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,))) y = (x - mean) * mul + if self.bias: - y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype) + y = y + jnp.asarray(self.param("beta", self.bias_init, (features,))) return y @@ -138,7 +141,9 @@ class FlaxBertEmbedding(nn.Module): vocab_size: int hidden_size: int - emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) + kernel_init_scale: float = 0.2 + emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale) + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact def __call__(self, inputs): @@ -153,63 +158,105 @@ class FlaxBertEmbeddings(nn.Module): hidden_size: int type_vocab_size: int max_length: int + kernel_init_scale: float = 0.2 + dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): # Embed - w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")( - jnp.atleast_2d(input_ids.astype("i4")) - ) - p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")( - jnp.atleast_2d(position_ids.astype("i4")) - ) - t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")( - jnp.atleast_2d(token_type_ids.astype("i4")) - ) + w_emb = FlaxBertEmbedding( + self.vocab_size, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="word_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(input_ids.astype("i4"))) + p_emb = FlaxBertEmbedding( + self.max_length, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="position_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(position_ids.astype("i4"))) + t_emb = FlaxBertEmbedding( + self.type_vocab_size, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="token_type_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(token_type_ids.astype("i4"))) # Sum all embeddings summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb # Layer Norm - layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb) - - return layer_norm + layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb) + embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic) + return embeddings class FlaxBertAttention(nn.Module): num_heads: int head_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( - hidden_state, attention_mask - ) + self_att = nn.attention.SelfAttention( + num_heads=self.num_heads, + qkv_features=self.head_size, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + bias_init=jax.nn.initializers.zeros, + name="self", + dtype=self.dtype, + )(hidden_state, attention_mask) - layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state) + layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state) return layer_norm class FlaxBertIntermediate(nn.Module): output_size: int + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact def __call__(self, hidden_state): # TODO: Add ACT2FN reference to change activation function - dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) + dense = nn.Dense( + features=self.output_size, + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(hidden_state) return gelu(dense) class FlaxBertOutput(nn.Module): + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + @nn.compact - def __call__(self, intermediate_output, attention_output): - hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) - hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output) + def __call__(self, intermediate_output, attention_output, deterministic: bool = True): + hidden_state = nn.Dense( + attention_output.shape[-1], + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(intermediate_output) + hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic) + hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output) return hidden_state @@ -217,12 +264,26 @@ class FlaxBertLayer(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): - attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask) - intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention) - output = FlaxBertOutput(name="output")(intermediate, attention) + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + attention = FlaxBertAttention( + self.num_heads, + self.head_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="attention", + dtype=self.dtype, + )(hidden_state, attention_mask, deterministic=deterministic) + intermediate = FlaxBertIntermediate( + self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype + )(attention) + output = FlaxBertOutput( + kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype + )(intermediate, attention, deterministic=deterministic) return output @@ -236,9 +297,12 @@ class FlaxBertLayerCollection(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, inputs, attention_mask): + def __call__(self, inputs, attention_mask, deterministic: bool = True): assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" # Initialize input / output @@ -246,8 +310,16 @@ class FlaxBertLayerCollection(nn.Module): # Forward over all encoders for i in range(self.num_layers): - layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}") - input_i = layer(input_i, attention_mask) + layer = FlaxBertLayer( + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name=f"{i}", + dtype=self.dtype, + ) + input_i = layer(input_i, attention_mask, deterministic=deterministic) return input_i @@ -256,21 +328,39 @@ class FlaxBertEncoder(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): layer = FlaxBertLayerCollection( - self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" - )(hidden_state, attention_mask) + self.num_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="layer", + dtype=self.dtype, + )(hidden_state, attention_mask, deterministic=deterministic) return layer class FlaxBertPooler(nn.Module): + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + @nn.compact def __call__(self, hidden_state): cls_token = hidden_state[:, 0] - out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) - return jax.lax.tanh(out) + out = nn.Dense( + hidden_state.shape[-1], + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(cls_token) + return nn.tanh(out) class FlaxBertModule(nn.Module): @@ -282,24 +372,104 @@ class FlaxBertModule(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids): + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): # Embedding embeddings = FlaxBertEmbeddings( - self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings" - )(input_ids, token_type_ids, position_ids, attention_mask) + self.vocab_size, + self.hidden_size, + self.type_vocab_size, + self.max_length, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="embeddings", + dtype=self.dtype, + )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) # N stacked encoding layers encoder = FlaxBertEncoder( - self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder" - )(embeddings, attention_mask) + self.num_encoder_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="encoder", + dtype=self.dtype, + )(embeddings, attention_mask, deterministic=deterministic) - pooled = FlaxBertPooler(name="pooler")(encoder) + pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) return encoder, pooled +class FlaxBertPredictionHeadTransform(nn.Module): + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, hidden_states): + hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states) + hidden_states = nn.elu(hidden_states) # TODO: ACT2FN[config.hidden_act] + return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states) + + +class FlaxBertLMPredictionHead(nn.Module): + vocab_size: int + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, hidden_states): + # TODO: The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + # Need a link between the two variables so that the bias is correctly + # resized with `resize_token_embeddings` + + hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states) + hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states) + return hidden_states + + +class FlaxBertOnlyMLMHead(nn.Module): + vocab_size: int + hidden_size: int + intermediate_size: int + head_size: int + num_heads: int + num_encoder_layers: int + type_vocab_size: int + max_length: int + dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + encoder, pooled = FlaxBertModule( + vocab_size=self.vocab_size, + type_vocab_size=self.type_vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + head_size=self.hidden_size, + num_heads=self.num_heads, + num_encoder_layers=self.num_encoder_layers, + max_length=self.max_length, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + + # Compute the prediction scores + encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic) + logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder) + + return logits, pooled + + @add_start_docstrings( "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", BERT_START_DOCSTRING, @@ -385,8 +555,8 @@ class FlaxBertModel(FlaxPreTrainedModel): return jax_state - def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): - model = FlaxBertModule( + def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32): + module = FlaxBertModule( vocab_size=config.vocab_size, hidden_size=config.hidden_size, type_vocab_size=config.type_vocab_size, @@ -395,16 +565,43 @@ class FlaxBertModel(FlaxPreTrainedModel): num_heads=config.num_attention_heads, head_size=config.hidden_size, intermediate_size=config.intermediate_size, + dropout_rate=config.hidden_dropout_prob, + dtype=dtype, ) - super().__init__(config, model, state, seed) - - @property - def module(self) -> nn.Module: - return self._module + super().__init__(config, module, state, seed) @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None): + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -414,10 +611,62 @@ class FlaxBertModel(FlaxPreTrainedModel): if attention_mask is None: attention_mask = jnp.ones_like(input_ids) - return self.model.apply( - {"params": self.params}, + return input_ids, attention_mask, token_type_ids, position_ids + + def init(self, rng: jax.random.PRNGKey, input_shape: Tuple): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + jnp.zeros(input_shape, dtype="i4"), None, None, None + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + + +class FlaxBertForMaskedLM(FlaxBertModel): + def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): + super().__init__(config, state, seed, dtype) + + self._module = FlaxBertOnlyMLMHead( + vocab_size=config.vocab_size, + type_vocab_size=config.type_vocab_size, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + head_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_encoder_layers=config.num_hidden_layers, + max_length=config.max_length, + **kwargs, + ) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + pooled, logits = self.module.apply( + {"params": params or self.params}, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, ) + + return logits, pooled diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index c21556f03e..bafbdfc4d4 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict +from typing import Callable, Dict, Tuple import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_flax_utils import FlaxPreTrainedModel, gelu @@ -101,8 +102,8 @@ class FlaxRobertaLayerNorm(nn.Module): scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear # (also e.g. nn.relu), this can be disabled since the scaling will be # done by the next layer. - bias_init: jnp.ndarray = nn.initializers.zeros - scale_init: jnp.ndarray = nn.initializers.ones + scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros @nn.compact def __call__(self, x): @@ -122,11 +123,13 @@ class FlaxRobertaLayerNorm(nn.Module): mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) var = mean2 - jax.lax.square(mean) mul = jax.lax.rsqrt(var + self.epsilon) + if self.scale: - mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype) + mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,))) y = (x - mean) * mul + if self.bias: - y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype) + y = y + jnp.asarray(self.param("beta", self.bias_init, (features,))) return y @@ -139,7 +142,9 @@ class FlaxRobertaEmbedding(nn.Module): vocab_size: int hidden_size: int - emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) + kernel_init_scale: float = 0.2 + emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale) + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact def __call__(self, inputs): @@ -155,66 +160,108 @@ class FlaxRobertaEmbeddings(nn.Module): hidden_size: int type_vocab_size: int max_length: int + kernel_init_scale: float = 0.2 + dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): # Embed - w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")( - jnp.atleast_2d(input_ids.astype("i4")) - ) - p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")( - jnp.atleast_2d(position_ids.astype("i4")) - ) - t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")( - jnp.atleast_2d(token_type_ids.astype("i4")) - ) + w_emb = FlaxRobertaEmbedding( + self.vocab_size, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="word_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(input_ids.astype("i4"))) + p_emb = FlaxRobertaEmbedding( + self.max_length, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="position_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(position_ids.astype("i4"))) + t_emb = FlaxRobertaEmbedding( + self.type_vocab_size, + self.hidden_size, + kernel_init_scale=self.kernel_init_scale, + name="token_type_embeddings", + dtype=self.dtype, + )(jnp.atleast_2d(token_type_ids.astype("i4"))) # Sum all embeddings summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb # Layer Norm - layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb) - - return layer_norm + layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb) + embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic) + return embeddings # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta class FlaxRobertaAttention(nn.Module): num_heads: int head_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( - hidden_state, attention_mask - ) + self_att = nn.attention.SelfAttention( + num_heads=self.num_heads, + qkv_features=self.head_size, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + bias_init=jax.nn.initializers.zeros, + name="self", + dtype=self.dtype, + )(hidden_state, attention_mask) - layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state) + layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state) return layer_norm # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta class FlaxRobertaIntermediate(nn.Module): output_size: int + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact def __call__(self, hidden_state): # TODO: Add ACT2FN reference to change activation function - dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) + dense = nn.Dense( + features=self.output_size, + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(hidden_state) return gelu(dense) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta class FlaxRobertaOutput(nn.Module): + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + @nn.compact - def __call__(self, intermediate_output, attention_output): - hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) - hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output) + def __call__(self, intermediate_output, attention_output, deterministic: bool = True): + hidden_state = nn.Dense( + attention_output.shape[-1], + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(intermediate_output) + hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic) + hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output) return hidden_state @@ -222,14 +269,29 @@ class FlaxRobertaLayer(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): - attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")( - hidden_state, attention_mask - ) - intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention) - output = FlaxRobertaOutput(name="output")(intermediate, attention) + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): + attention = FlaxRobertaAttention( + self.num_heads, + self.head_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="attention", + dtype=self.dtype, + )(hidden_state, attention_mask, deterministic=deterministic) + intermediate = FlaxRobertaIntermediate( + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + name="intermediate", + dtype=self.dtype, + )(attention) + output = FlaxRobertaOutput( + kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype + )(intermediate, attention, deterministic=deterministic) return output @@ -244,9 +306,12 @@ class FlaxRobertaLayerCollection(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, inputs, attention_mask): + def __call__(self, inputs, attention_mask, deterministic: bool = True): assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" # Initialize input / output @@ -254,8 +319,16 @@ class FlaxRobertaLayerCollection(nn.Module): # Forward over all encoders for i in range(self.num_layers): - layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}") - input_i = layer(input_i, attention_mask) + layer = FlaxRobertaLayer( + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name=f"{i}", + dtype=self.dtype, + ) + input_i = layer(input_i, attention_mask, deterministic=deterministic) return input_i @@ -265,22 +338,40 @@ class FlaxRobertaEncoder(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, hidden_state, attention_mask): + def __call__(self, hidden_state, attention_mask, deterministic: bool = True): layer = FlaxRobertaLayerCollection( - self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" - )(hidden_state, attention_mask) + self.num_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="layer", + dtype=self.dtype, + )(hidden_state, attention_mask, deterministic=deterministic) return layer # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta class FlaxRobertaPooler(nn.Module): + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + @nn.compact def __call__(self, hidden_state): cls_token = hidden_state[:, 0] - out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) - return jax.lax.tanh(out) + out = nn.Dense( + hidden_state.shape[-1], + kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + name="dense", + dtype=self.dtype, + )(cls_token) + return nn.tanh(out) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta @@ -293,21 +384,38 @@ class FlaxRobertaModule(nn.Module): num_heads: int head_size: int intermediate_size: int + dropout_rate: float = 0.0 + kernel_init_scale: float = 0.2 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation @nn.compact - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids): + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): # Embedding embeddings = FlaxRobertaEmbeddings( - self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings" - )(input_ids, token_type_ids, position_ids, attention_mask) + self.vocab_size, + self.hidden_size, + self.type_vocab_size, + self.max_length, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="embeddings", + dtype=self.dtype, + )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) # N stacked encoding layers encoder = FlaxRobertaEncoder( - self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder" - )(embeddings, attention_mask) + self.num_encoder_layers, + self.num_heads, + self.head_size, + self.intermediate_size, + kernel_init_scale=self.kernel_init_scale, + dropout_rate=self.dropout_rate, + name="encoder", + dtype=self.dtype, + )(embeddings, attention_mask, deterministic=deterministic) - pooled = FlaxRobertaPooler(name="pooler")(encoder) + pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) return encoder, pooled @@ -396,8 +504,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel): return jax_state - def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs): - model = FlaxRobertaModule( + def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32): + module = FlaxRobertaModule( vocab_size=config.vocab_size, hidden_size=config.hidden_size, type_vocab_size=config.type_vocab_size, @@ -406,31 +514,78 @@ class FlaxRobertaModel(FlaxPreTrainedModel): num_heads=config.num_attention_heads, head_size=config.hidden_size, intermediate_size=config.intermediate_size, + dropout_rate=config.hidden_dropout_prob, + dtype=dtype, ) - super().__init__(config, model, state, seed) - - @property - def module(self) -> nn.Module: - return self._module + super().__init__(config, module, state, seed) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None): - if token_type_ids is None: - token_type_ids = jnp.ones_like(input_ids) + def __call__( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) - if position_ids is None: - position_ids = jnp.arange( - self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1 - ) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - return self.model.apply( - {"params": self.params}, + return self.module.apply( + {"params": params or self.params}, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, ) + + def init(self, rng: jax.random.PRNGKey, input_shape: Tuple): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + jnp.zeros(input_shape, dtype="i4"), None, None, None + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] + + def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids): + + if token_type_ids is None: + token_type_ids = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + return input_ids, attention_mask, token_type_ids, position_ids + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = (input_ids != padding_idx).astype("i4") + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + return incremental_indices.astype("i4") + padding_idx diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 84f9853842..3c9b204b10 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -14,6 +14,15 @@ class FlaxAutoModel: requires_flax(self) +class FlaxBertForMaskedLM: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + class FlaxBertModel: def __init__(self, *args, **kwargs): requires_flax(self) diff --git a/tests/test_modeling_flax_roberta.py b/tests/test_modeling_flax_roberta.py index 1d6186f12f..3c60b17ab8 100644 --- a/tests/test_modeling_flax_roberta.py +++ b/tests/test_modeling_flax_roberta.py @@ -57,7 +57,7 @@ class FlaxRobertaModelTest(unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()): - self.assert_almost_equals(fx_output, pt_output.numpy(), 6e-4) + self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) def test_multiple_sequences(self): tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")