From 4eef5889acc42b410d9b2e34463845bc4a96303a Mon Sep 17 00:00:00 2001 From: Teven Date: Mon, 21 Dec 2020 21:19:41 +0100 Subject: [PATCH] Adding performer fine-tuning research exampke (#9239) * added run_mlm_performer.py research example * make styke * make styke * Added a README ! --- .../research_projects/performer/README.md | 25 + .../performer/full_script.sh | 1 + .../performer/modeling_flax_performer.py | 553 ++++++++++++++ .../modeling_flax_performer_utils.py | 660 +++++++++++++++++ .../performer/run_mlm_performer.py | 685 ++++++++++++++++++ .../performer/sanity_script.sh | 1 + 6 files changed, 1925 insertions(+) create mode 100644 examples/research_projects/performer/README.md create mode 100755 examples/research_projects/performer/full_script.sh create mode 100644 examples/research_projects/performer/modeling_flax_performer.py create mode 100644 examples/research_projects/performer/modeling_flax_performer_utils.py create mode 100644 examples/research_projects/performer/run_mlm_performer.py create mode 100755 examples/research_projects/performer/sanity_script.sh diff --git a/examples/research_projects/performer/README.md b/examples/research_projects/performer/README.md new file mode 100644 index 0000000000..42cb6fa358 --- /dev/null +++ b/examples/research_projects/performer/README.md @@ -0,0 +1,25 @@ +# Performer fine-tuning + +Example authors: @TevenLeScao, @Patrickvonplaten + +Paper authors: Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, Adrian Weller + +## Requirements + +`datasets`, `flax` and `jax`. `wandb` integration is built-in if you want to use it. + +## Examples + +`sanity_script.sh` will launch performer fine-tuning from the bert-base-cased checkpoint on the Simple Wikipedia dataset (a small, easy-language English Wikipedia) from `datasets`. 
+`full_script.sh` will launch performer fine-tuning from the bert-large-cased checkpoint on the English Wikipedia dataset from `datasets`. + +Here are a few key arguments: +- Remove the `--performer` argument to use a standard Bert model. + +- Add `--reinitialize` to start from a blank model rather than a Bert checkpoint. + +- You may change the Bert size by passing a different [checkpoint](https://huggingface.co/transformers/pretrained_models.html) to the `--model_name_or_path` argument. + +- Passing your user name to the `--wandb_user_name` argument will trigger weights and biases logging. + +- You can choose a dataset with `--dataset_name` and `--dataset_config_name`. Our [viewer](https://huggingface.co/datasets/viewer/) will help you find what you need. \ No newline at end of file diff --git a/examples/research_projects/performer/full_script.sh b/examples/research_projects/performer/full_script.sh new file mode 100755 index 0000000000..8634666f98 --- /dev/null +++ b/examples/research_projects/performer/full_script.sh @@ -0,0 +1 @@ +TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.en --model_name_or_path bert-large-cased --tokenizer_name bert-large-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer \ No newline at end of file diff --git a/examples/research_projects/performer/modeling_flax_performer.py b/examples/research_projects/performer/modeling_flax_performer.py new file mode 100644 index 0000000000..b4b9924fae --- /dev/null +++ b/examples/research_projects/performer/modeling_flax_performer.py @@ -0,0 +1,553 @@ +# coding=utf-8 +# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from jax.random import PRNGKey +from modeling_flax_performer_utils import make_fast_softmax_attention +from transformers.file_utils import add_start_docstrings +from transformers.modeling_flax_utils import ACT2FN +from transformers.models.bert.configuration_bert import BertConfig +from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead, FlaxBertPreTrainedModel +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +BERT_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. 
+""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 
class FlaxPerformerLayerNorm(nn.Module):
    """
    Layer normalization (https://arxiv.org/abs/1607.06450) computed over the last axis of the input.
    """

    epsilon: float = 1e-6  # added to the variance before taking the reciprocal square root
    dtype: jnp.dtype = jnp.float32  # dtype the gamma / beta parameters are cast to
    bias: bool = True  # when True, the learned offset "beta" is added
    # When True, the learned gain "gamma" is multiplied in. If the next layer is linear
    # (also e.g. nn.relu) this can be disabled, since that layer can do the scaling itself.
    scale: bool = True
    bias_init: jnp.ndarray = nn.initializers.zeros
    scale_init: jnp.ndarray = nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        """
        Normalize every vector along the last axis of `x` to (approximately) zero mean and unit variance.

        Statistics are computed per example, never across the batch. The variance is obtained as
        E[x^2] - E[x]^2.

        Args:
            x: the inputs

        Returns:
            The normalized inputs, same shape as `x`, optionally scaled by "gamma" and shifted by "beta".
        """
        num_features = x.shape[-1]
        first_moment = jnp.mean(x, axis=-1, keepdims=True)
        second_moment = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        variance = second_moment - jax.lax.square(first_moment)
        inv_std = jax.lax.rsqrt(variance + self.epsilon)

        if self.scale:
            gamma = self.param("gamma", self.scale_init, (num_features,))
            inv_std = inv_std * jnp.asarray(gamma, self.dtype)

        normalized = (x - first_moment) * inv_std
        if not self.bias:
            return normalized

        beta = self.param("beta", self.bias_init, (num_features,))
        return normalized + jnp.asarray(beta, self.dtype)


class FlaxPerformerEmbedding(nn.Module):
    """
    Minimal embedding table whose parameter is called 'weight' (the PyTorch name) instead of the 'embedding' name
    used by Flax's own embedding layer.
    """

    vocab_size: int
    hidden_size: int
    emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)

    @nn.compact
    def __call__(self, inputs):
        """Return the rows of the (vocab_size, hidden_size) table selected by the integer array `inputs`."""
        table = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
        return jnp.take(table, inputs, axis=0)


class FlaxPerformerEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, followed by layer normalization."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        """
        Embed the three id arrays and normalize their sum.

        `attention_mask` is accepted for signature compatibility but is not read here.
        """
        # All id arrays are cast to int32 and promoted to at least 2D before the lookup.
        word_ids = jnp.atleast_2d(input_ids.astype("i4"))
        pos_ids = jnp.atleast_2d(position_ids.astype("i4"))
        type_ids = jnp.atleast_2d(token_type_ids.astype("i4"))

        word_table = FlaxPerformerEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")
        position_table = FlaxPerformerEmbedding(self.max_length, self.hidden_size, name="position_embeddings")
        type_table = FlaxPerformerEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")

        words = word_table(word_ids)
        positions = position_table(pos_ids)
        types = type_table(type_ids)

        # Position embeddings are broadcast to the shape of the word embeddings before summing.
        total = words + jnp.broadcast_to(positions, words.shape) + types

        return FlaxPerformerLayerNorm(name="layer_norm")(total)
class FlaxPerformerAttention(nn.Module):
    """Self-attention sub-layer using fast softmax attention, with a residual connection and layer norm."""

    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        # `head_size` is the total attention width; each head gets an equal integer share of it.
        per_head_dim = self.head_size // self.num_heads
        attention_fn = make_fast_softmax_attention(qkv_dim=per_head_dim)

        self_attention = nn.attention.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.head_size,
            name="self",
            attention_fn=attention_fn,
        )
        attended = self_attention(hidden_state, attention_mask)

        # Residual connection followed by layer normalization.
        return FlaxPerformerLayerNorm(name="layer_norm")(attended + hidden_state)


class FlaxPerformerIntermediate(nn.Module):
    """Dense projection to `output_size` followed by the activation named by `hidden_act`."""

    output_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state):
        projected = nn.Dense(features=self.output_size, name="dense")(hidden_state)
        activation = ACT2FN[self.hidden_act]
        return activation(projected)


class FlaxPerformerOutput(nn.Module):
    """Dense projection back to the attention width, with a residual connection and layer norm."""

    @nn.compact
    def __call__(self, intermediate_output, attention_output):
        # The output width is taken from the last axis of the attention output.
        width = attention_output.shape[-1]
        projected = nn.Dense(width, name="dense")(intermediate_output)
        return FlaxPerformerLayerNorm(name="layer_norm")(projected + attention_output)


class FlaxPerformerLayer(nn.Module):
    """One encoder layer: attention sub-layer, intermediate projection, then output projection."""

    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        attention_block = FlaxPerformerAttention(self.num_heads, self.head_size, name="attention")
        attended = attention_block(hidden_state, attention_mask)

        intermediate_block = FlaxPerformerIntermediate(
            self.intermediate_size, name="intermediate", hidden_act=self.hidden_act
        )
        expanded = intermediate_block(attended)

        # The output block receives both the expanded activations and the attention output (residual).
        return FlaxPerformerOutput(name="output")(expanded, attended)
class FlaxPerformerLayerCollection(nn.Module):
    """
    Stack of `num_layers` FlaxPerformerLayer modules applied one after another.
    """

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, inputs, attention_mask):
        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"

        hidden = inputs
        # Each layer is named after its index ("0", "1", ...) and consumes the previous layer's output.
        for index in range(self.num_layers):
            hidden = FlaxPerformerLayer(
                self.num_heads,
                self.head_size,
                self.intermediate_size,
                hidden_act=self.hidden_act,
                name=str(index),
            )(hidden, attention_mask)
        return hidden


class FlaxPerformerEncoder(nn.Module):
    """Encoder: a single FlaxPerformerLayerCollection registered under the name "layer"."""

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        stack = FlaxPerformerLayerCollection(
            self.num_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            name="layer",
            hidden_act=self.hidden_act,
        )
        return stack(hidden_state, attention_mask)


class FlaxPerformerPooler(nn.Module):
    """Dense layer plus tanh applied to the element at index 0 of the second axis."""

    @nn.compact
    def __call__(self, hidden_state):
        first_token = hidden_state[:, 0]
        projected = nn.Dense(hidden_state.shape[-1], name="dense")(first_token)
        return jax.lax.tanh(projected)


class FlaxPerformerModule(nn.Module):
    """Embeddings followed by the encoder stack and, optionally, the pooler."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    num_encoder_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    add_pooling_layer: bool = True

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        """
        Run the embeddings and the encoder.

        Note the parameter order: `token_type_ids` and `position_ids` come before `attention_mask`.

        Returns:
            The encoder output alone when `add_pooling_layer` is False, otherwise the tuple
            `(encoder_output, pooled_output)`.
        """
        embed = FlaxPerformerEmbeddings(
            self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
        )
        embedded = embed(input_ids, token_type_ids, position_ids, attention_mask)

        encode = FlaxPerformerEncoder(
            self.num_encoder_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            hidden_act=self.hidden_act,
            name="encoder",
        )
        encoded = encode(embedded, attention_mask)

        if self.add_pooling_layer:
            return encoded, FlaxPerformerPooler(name="pooler")(encoded)
        return encoded
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxPerformerModel(FlaxBertPreTrainedModel):
    """
    Bert-shaped encoder built from :class:`FlaxPerformerModule`, i.e. with the self-attention of every layer
    computed by the fast softmax attention created in :class:`FlaxPerformerAttention`.
    """

    model_class = FlaxPerformerModule
    config_class = BertConfig
    base_model_prefix = "bert"

    @staticmethod
    def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
        """
        Rename and reshape the entries of a PyTorch Bert state dict so that they match the Flax parameter names.

        Args:
            pt_state: mapping from PyTorch parameter name to tensor. The tensors must support ``reshape``,
                ``transpose`` and ``.T``.
            config: configuration providing ``num_attention_heads`` and ``hidden_size``.

        Returns:
            A new dict; ``pt_state`` itself is not modified.
        """
        jax_state = dict(pt_state)

        # Need to change some parameters name to match Flax names so that we don't have to fork any layer
        for key, tensor in pt_state.items():
            # Dot-separated components of the *original* key
            key_parts = set(key.split("."))

            # Every dense layer has "kernel" parameters instead of "weight"
            if "dense.weight" in key:
                del jax_state[key]
                key = key.replace("weight", "kernel")
                jax_state[key] = tensor

            # SelfAttention needs also to replace "weight" by "kernel"
            if {"query", "key", "value"} & key_parts:

                # Flax SelfAttention decomposes the heads (num_head, size // num_heads)
                if "bias" in key:
                    jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
                elif "weight" in key:  # was `elif "weight":`, which is always true
                    del jax_state[key]
                    key = key.replace("weight", "kernel")
                    tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
                    jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove one nesting
            if "attention.output.dense" in key:
                del jax_state[key]
                key = key.replace("attention.output.dense", "attention.self.out")
                jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove nesting on layer norm
            if "attention.output.LayerNorm" in key:
                del jax_state[key]
                key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
                jax_state[key] = tensor

            # There are some transposed parameters w.r.t their PyTorch counterpart
            if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Self Attention output projection needs to be transposed
            if "out.kernel" in key:
                jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
                    1, 2, 0
                )

            # Pooler needs to transpose its kernel
            if "pooler.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Handle LayerNorm conversion
            if "LayerNorm" in key:
                del jax_state[key]

                # Replace LayerNorm by layer_norm
                new_key = key.replace("LayerNorm", "layer_norm")

                if "weight" in key:
                    new_key = new_key.replace("weight", "gamma")
                elif "bias" in key:
                    new_key = new_key.replace("bias", "beta")

                jax_state[new_key] = tensor

        return jax_state

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        """
        Build the Flax module from ``config`` and hand it to the parent class.

        ``**kwargs`` is accepted for signature compatibility and is not forwarded.
        """
        # FlaxPerformerModule declares no `dropout_rate` field, so none is passed: doing so raises a TypeError
        # when the module dataclass is constructed.
        module = FlaxPerformerModule(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            type_vocab_size=config.type_vocab_size,
            max_length=config.max_position_embeddings,
            num_encoder_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            head_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @property
    def module(self) -> nn.Module:
        """The wrapped Flax module."""
        return self._module

    def __call__(
        self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
    ):
        """
        Apply the module with the model's own parameters.

        Args:
            input_ids: token indices.
            token_type_ids, position_ids, attention_mask: optional; they go through ``self._check_inputs``
                (defined on the parent class) before use.
            dropout_rng: optional PRNG key, exposed to the module as the ``"dropout"`` rng stream.

        Returns:
            Whatever :class:`FlaxPerformerModule` returns.
        """
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Positional order follows FlaxPerformerModule.__call__:
        # (input_ids, token_type_ids, position_ids, attention_mask).
        return self.module.apply(
            {"params": self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            rngs=rngs,  # was `rng=rngs`; `rngs` is the keyword FlaxPerformerForMaskedLM uses as well
        )


class FlaxPerformerForMaskedLM(FlaxBertPreTrainedModel):
    """Model wrapper around :class:`FlaxPerformerForMaskedLMModule` (Performer encoder plus MLM head)."""

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        """Build the masked-LM module from ``config``; extra ``**kwargs`` are forwarded to the module."""
        module = FlaxPerformerForMaskedLMModule(
            vocab_size=config.vocab_size,
            type_vocab_size=config.type_vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            max_length=config.max_position_embeddings,
            hidden_act=config.hidden_act,
            **kwargs,
        )

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        train: bool = False,
        dropout_rng: PRNGKey = None,
    ):
        """
        Apply the masked-LM module.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: model inputs; they go through
                ``self._check_inputs`` (defined on the parent class) before use.
            params: parameters to use instead of ``self.params`` when truthy.
            train: the module is called with ``deterministic = not train``.
            dropout_rng: optional PRNG key, exposed to the module as the ``"dropout"`` rng stream.

        Returns:
            The module output, a 1-tuple ``(logits,)``.
        """
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Positional order follows FlaxPerformerForMaskedLMModule.__call__.
        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
class FlaxPerformerForMaskedLMModule(nn.Module):
    """Performer encoder (without pooler), dropout, and the Bert masked-LM head."""

    vocab_size: int
    hidden_size: int
    intermediate_size: int
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    hidden_act: str
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        """
        Compute masked-LM logits.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: forwarded to the encoder.
            deterministic: forwarded to the dropout layer.

        Returns:
            A 1-tuple ``(logits,)``.
        """
        # NOTE: the encoder's attention width is taken from `hidden_size`; the `head_size` field is not read here.
        encoder_module = FlaxPerformerModule(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            type_vocab_size=self.type_vocab_size,
            max_length=self.max_length,
            num_encoder_layers=self.num_encoder_layers,
            num_heads=self.num_heads,
            head_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            add_pooling_layer=False,
            name="bert",
        )
        # FlaxPerformerModule.__call__ is declared as (input_ids, token_type_ids, position_ids, attention_mask),
        # which differs from this method's order. Passing by keyword keeps each tensor in its own slot; the
        # previous positional call fed the attention mask in as token type ids, and so on.
        encoder = encoder_module(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
        )

        # Compute the prediction scores
        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
        logits = FlaxBertOnlyMLMHead(
            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
        )(encoder)

        return (logits,)
def nonnegative_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
):
    """
    Build the nonnegative (exponential) random features used by fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features; its first axis is the number of features
        attention_dims_t: tuple of attention dimensions (unused by this function)
        batch_dims_t: tuple of batch dimensions
        precision: precision argument handed to `lax.dot_general`
        is_query: True when `data` holds queries, False when it holds keys
        normalize_data: whether `data` is scaled by d^(-1/4), d being the size of its last axis
        eps: numerical stabilizer added to every feature

    Returns:
        Random features for fast softmax attention; the last axis has one entry per projection row.
    """
    del attention_dims_t

    # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
    # w_norm = w * data_normalizer for w in {q,k}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])

    # Replicate the projection matrix along the leading batch dimensions so it can be used as a batched operand.
    batched_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    batched_projection = jnp.zeros(batched_shape) + projection_matrix

    last_axis = data.ndim - 1
    contracting_dims = ((last_axis,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        data_normalizer * data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )

    # Half the squared norm of the (normalized) data, kept as a trailing axis of size one.
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * data_normalizer * data_normalizer
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Subtract a maximum before exponentiating: per feature vector for queries, global for keys.
    if is_query:
        stabilizer = jnp.max(projected, axis=(len(projected.shape) - 1,), keepdims=True)
    else:
        stabilizer = jnp.max(projected)

    return ratio * (jnp.exp(projected - half_sq_norm - stabilizer) + eps)
def sincos_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
):
    """
    Build the trigonometric (sin/cos) random features used by fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features
        attention_dims_t: tuple of attention dimensions; the renormalizing maximum is taken over these axes
        batch_dims_t: tuple of batch dimensions
        precision: precision argument handed to `lax.dot_general`
        normalize_data: whether `data` is scaled by d^(-1/4), d being the size of its last axis

    Returns:
        Random features for fast softmax attention. The last axis holds the cosine features followed by the sine
        features, i.e. twice the number of projection rows.
    """
    # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
    # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])

    # Replicate the projection matrix along the leading batch dimensions.
    batched_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    batched_projection = jnp.zeros(batched_shape) + projection_matrix

    last_axis = data.ndim - 1
    contracting_dims = ((last_axis,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        data_normalizer * data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    trig_features = jnp.concatenate((ratio * jnp.cos(projected), ratio * jnp.sin(projected)), axis=-1)

    # exp(|x|^2 / 2) factor, computed from half the squared norm of the (normalized) data.
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * data_normalizer * data_normalizer
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Additional renormalization for numerical stability
    shifted = half_sq_norm - jnp.max(half_sq_norm, attention_dims_t, keepdims=True)

    return trig_features * jnp.exp(shifted)


def generalized_kernel_feature_creator(
    data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
):
    """
    Build kernel features for fast generalized attention.

    Args:
        data: input for which features are computed
        projection_matrix: matrix used to compute features, or None to apply the kernel to the data directly
        batch_dims_t: tuple of batch dimensions
        precision: precision argument handed to `lax.dot_general`
        kernel_fn: kernel function used
        kernel_epsilon: additive positive term added to every feature for numerical stability
        normalize_data: whether `data` is scaled by d^(-1/4), d being the size of its last axis

    Returns:
        Features for fast generalized attention.
    """
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    scaled_data = data_normalizer * data

    # Without a projection the kernel acts on the (scaled) data itself.
    if projection_matrix is None:
        return kernel_fn(scaled_data) + kernel_epsilon

    # Replicate the projection matrix along the leading batch dimensions.
    batched_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    batched_projection = jnp.zeros(batched_shape) + projection_matrix
    contracting_dims = ((data.ndim - 1,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        scaled_data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    return kernel_fn(projected) + kernel_epsilon
def make_fast_softmax_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.000001,
    nb_features=256,
    ortho_features=True,
    ortho_scaling=0.0,
    redraw_features=True,
    unidirectional=False,
    nonnegative_features=True,
    lax_scan_unroll=1,
):
    """
    Construct a fast softmax attention method.

    Picks a random-matrix factory (orthogonal or unstructured Gaussian, `nb_features` x `qkv_dim`) and a feature
    map (nonnegative or sin/cos), and returns the `dot_product_attention` bound method of a
    `FastAttentionviaLowRankDecomposition` configured with them.
    """
    logging.info(
        "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
        nb_features,
        ortho_features,
        renormalize_attention,
    )

    # The PRNG key is left unbound; it is the remaining constructor argument of both matrix classes.
    matrix_creator = (
        functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
        if ortho_features
        else functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    )

    def nonnegative_feature_map(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
    ):
        # `numerical_stabilizer` is passed as the `eps` of the nonnegative feature map.
        return nonnegative_softmax_kernel_feature_creator(
            data,
            projection_matrix,
            attention_dims_t,
            batch_dims_t,
            precision,
            is_query,
            normalize_data,
            numerical_stabilizer,
        )

    def sincos_feature_map(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
    ):
        # The sin/cos feature map makes no query/key distinction.
        del is_query
        return sincos_softmax_kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
        )

    kernel_feature_creator = nonnegative_feature_map if nonnegative_features else sincos_feature_map

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention


def make_fast_generalized_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.0,
    nb_features=256,
    features_type="deterministic",
    kernel_fn=jax.nn.relu,
    kernel_epsilon=0.001,
    redraw_features=False,
    unidirectional=False,
    lax_scan_unroll=1,
):
    """
    Construct a fast generalized attention method.

    `features_type` selects the projection: "ortho" (Gaussian orthogonal), "iid" (unstructured Gaussian) or
    "deterministic" (no projection matrix). Any other value raises ValueError.
    """
    logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)

    if features_type == "deterministic":
        matrix_creator = None
    elif features_type == "iid":
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    elif features_type == "ortho":
        # `scaling=False` compares equal to 0 in GaussianOrthogonalRandomMatrix.get_2d_array.
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
    else:
        raise ValueError("Unknown feature value type")

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
    ):
        # The generalized feature map uses neither the attention dims nor the query/key flag.
        del attention_dims_t
        del is_query
        return generalized_kernel_feature_creator(
            data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
        )

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention
Class is responsible for constructing 2D + random arrays. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def get_2d_array(self): + raise NotImplementedError("Abstract method") + + +class GaussianUnstructuredRandomMatrix(RandomMatrix): + def __init__(self, nb_rows, nb_columns, key): + self.nb_rows = nb_rows + self.nb_columns = nb_columns + self.key = key + + def get_2d_array(self): + return random.normal(self.key, (self.nb_rows, self.nb_columns)) + + +class GaussianOrthogonalRandomMatrix(RandomMatrix): + r""" + Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian + orthogonal arrays. + """ + + def __init__(self, nb_rows, nb_columns, key, scaling=0): + self.nb_rows = nb_rows + self.nb_columns = nb_columns + self.key = key + self.scaling = scaling + + def get_2d_array(self): + nb_full_blocks = int(self.nb_rows / self.nb_columns) + block_list = [] + rng = self.key + for _ in range(nb_full_blocks): + rng, rng_input = jax.random.split(rng) + unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q) + remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns + if remaining_rows > 0: + rng, rng_input = jax.random.split(rng) + unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q[0:remaining_rows]) + final_matrix = jnp.vstack(block_list) + + if self.scaling == 0: + multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1) + elif self.scaling == 1: + multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows)) + else: + raise ValueError("Scaling must be one of {0, 1}. 
Was %s" % self._scaling) + + return jnp.matmul(jnp.diag(multiplier), final_matrix) + + +class FastAttention(object): + r""" + Abstract class providing a method for fast attention. Class is responsible for providing a method + for fast approximate attention. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def dot_product_attention( + self, + query, + key, + value, + dtype=jnp.float32, + bias=None, + axis=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + precision=None, + ): + """ + Computes dot-product attention given query, key, and value. This is the core function for applying fast + approximate dot-product attention. It calculates the attention weights given query and key and combines the + values using the attention weights. This function supports multi-dimensional inputs + + Args: + query: queries for calculating attention with shape of [batch_size, dim1, + dim2, ..., dimN, num_heads, mem_channels]. + key: keys for calculating attention with shape of [batch_size, dim1, dim2, + ..., dimN, num_heads, mem_channels]. + value: values to be used in attention with shape of [batch_size, dim1, + dim2,..., dimN, num_heads, value_channels]. + dtype: the dtype of the computation (default: float32) + bias: bias for the attention weights. This can be used for incorporating + autoregressive mask, padding mask, proximity bias. + axis: axises over which the attention is applied. + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rng: JAX PRNGKey: to be used for dropout. + dropout_rate: dropout rate. + deterministic: bool, deterministic or not (to apply dropout). + precision: numerical precision of the computation see `jax.lax.Precision` + for details + + Returns: + Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels]. 
def _numerator(z_slice_shape, precision, unroll=1):
    """
    Build the scan-based numerator function together with its custom forward/backward rules.

    Args:
        z_slice_shape: shape of the running sum `p`; the scan carry is initialised to zeros of this shape.
        precision: forwarded to every `jnp.einsum` call below.
        unroll: forwarded to `lax.scan` as its `unroll` argument.

    Returns:
        A `jax.custom_vjp` function `(qs, ks, vs) -> W` whose vjp rules are `fwd` and `bwd` defined here. `lax.scan`
        iterates over the leading axis of `qs`, `ks` and `vs`.
    """

    def fwd(qs, ks, vs):
        def body(p, qkv):
            (q, k, v) = qkv
            # Add this step's outer product of k and v to the running sum.
            p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
            # Contract q with the running sum over the feature axis m.
            X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
            return p, X_slice

        init_value = jnp.zeros(z_slice_shape)
        p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
        # Residuals for bwd: only the final running sum is kept, plus the inputs.
        return W, (p, qs, ks, vs)

    def bwd(pqkv, W_ct):
        def body(carry, qkv_xct):
            # p: running sum as it was at this step of the forward scan; p_ct: accumulated cotangent of p.
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
            p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
            k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
            v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
            # Undo this step's forward update so the next (earlier) step sees its own running sum.
            p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        # The scan runs in reverse, starting from the final running sum and a zero cotangent for it.
        _, (qs_ct, ks_ct, vs_ct) = lax.scan(
            body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
        )
        return qs_ct, ks_ct, vs_ct

    @jax.custom_vjp
    def _numerator_impl(qs, ks, vs):
        W, _ = fwd(qs, ks, vs)
        return W

    _numerator_impl.defvjp(fwd, bwd)

    return _numerator_impl
class FastAttentionviaLowRankDecomposition(FastAttention):
    r"""
    Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method
    for fast dot-product attention with the use of low rank decomposition (e.g. with random
    feature maps).

    Args:
        matrix_creator: callable invoked as `matrix_creator(key=...)` whose result exposes `get_2d_array()`, or
            `None`, in which case no projection matrix is drawn.
        kernel_feature_creator: callable invoked as `kernel_feature_creator(data, projection_matrix,
            attention_dims_t, batch_dims_t, precision, is_query)` to build the query/key feature tensors.
        renormalize_attention: if falsy, the un-normalized product `W` is returned directly.
        numerical_stabilizer: threshold/offset applied to the normalizer `R` before taking its reciprocal.
        redraw_features: if truthy, the projection matrix is redrawn on every call to `dot_product_attention`.
        unidirectional: if truthy, the scan-based `_numerator` / `_denominator` path is used.
        lax_scan_unroll: forwarded to `_numerator` / `_denominator` as their `unroll` argument.
    """

    def __init__(
        self,
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention,
        numerical_stabilizer,
        redraw_features,
        unidirectional,
        lax_scan_unroll=1,
    ):  # For optimal GPU performance, set to 16.
        # The initial projection matrix is always drawn from the fixed key PRNGKey(0).
        rng = random.PRNGKey(0)
        # matrix_creator must be set before draw_weights is called, since draw_weights reads it.
        self.matrix_creator = matrix_creator
        self.projection_matrix = self.draw_weights(rng)
        self.kernel_feature_creator = kernel_feature_creator
        self.renormalize_attention = renormalize_attention
        self.numerical_stabilizer = numerical_stabilizer
        self.redraw_features = redraw_features
        self.unidirectional = unidirectional
        self.lax_scan_unroll = lax_scan_unroll

    def draw_weights(self, key):
        """Return a projection matrix drawn from `key`, or `None` when no `matrix_creator` was given."""
        if self.matrix_creator is None:
            return None
        matrixrng, _ = random.split(key)
        projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
        return projection_matrix

    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Compute attention through the kernel features of `query` and `key` instead of an explicit attention matrix.

        `query`, `key` and `value` must have the same number of dimensions; `key` and `value` must agree on every
        dimension but the last, and `query` and `key` on the first and the last. `axis` selects the attention
        axes and defaults to every axis between the first and the last two.

        NOTE(review): `dtype`, `bias`, `broadcast_dropout`, `dropout_rng`, `dropout_rate` and `deterministic` are
        accepted but never referenced in this body - confirm that ignoring them (in particular `bias`) is intended.

        Returns:
            The attention output, transposed back to the axis order of the inputs.

        Raises:
            ValueError: if an attention axis is not strictly between the first axis and the last two axes.
        """

        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis,)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # The seed is derived from the sum of the query entries; note this overwrites the stored
            # projection matrix on the instance.
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims: every axis that is neither an attention axis nor the last (channel) axis.
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
        # q & k are permuted to (batch dims, attention dims, channels).
        qk_perm = batch_dims + axis + (n - 1,)
        # key_extra puts the attention dims first; it is only used for its leading shape below.
        k_extra_perm = axis + batch_dims + (n - 1,)
        key_extra = key.transpose(k_extra_perm)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        # v is permuted the same way: (batch dims, attention dims, channels).
        v_perm = batch_dims + axis + (n - 1,)
        value = value.transpose(v_perm)
        # Positions of the batch / attention dims after the permutation above.
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(
            query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
        )
        key_prime = self.kernel_feature_creator(
            key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
        )

        if self.unidirectional:
            # Only the first attention axis is scanned over.
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)

            numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
            # The scanned axis is moved to the front for the scan and moved back afterwards.
            W = numerator_fn(
                jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
            )

            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)

            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Unidirectional, normalized attention.
                # NOTE(review): thick_all_ones is assigned here but not read again in this branch.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])

                index = attention_dims_t[0]
                t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
                denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
                R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))

                R = jnp.moveaxis(R, 0, index)
        else:
            # Contract the feature axis of Q^{'} with the first non-batch axis of Z.
            contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z: (batch dims, channels_m, channels_v) - the attention dims are contracted away.
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q: (batch dims, attention dims, channels_m)
            # Z: (batch dims, channels_m, channels_v)
            # W: (batch dims, attention dims, channels_v)
            W = lax.dot_general(
                query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
            )
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Bidirectional, normalized attention.
                # All-ones tensor with the shape of `key` minus its channel axis.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
                contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
                contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
                # Construct T = (K^{'})^{T} 1_L
                # k: (batch dims, attention dims, channels)
                T = lax.dot_general(
                    key_prime,
                    thick_all_ones,
                    ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
                    precision=precision,
                )

                # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
                # q_p: (batch dims, attention dims, channels_m)
                # T:   (batch dims, channels_m)
                R = lax.dot_general(
                    query_prime,
                    T,
                    (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))),
                    precision=precision,
                )

        # Entries of R with magnitude at most numerical_stabilizer are shifted by twice that value
        # before the reciprocal is taken.
        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        # Append a trailing axis of size 1 so R broadcasts against the channel axis of W.
        R = jnp.expand_dims(R, len(R.shape))
        # W: (batch dims, attention dims, channels_v)
        # R: (batch dims, attention dims, 1)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        perm_inv = _invert_perm(qk_perm)
        result = result.transpose(perm_inv)
        return result
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a +text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=masked-lm +""" +import logging +import os +import sys +from dataclasses import dataclass, field + +# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +from datasets import load_dataset +from tqdm import tqdm + +import jax +import jax.numpy as jnp +from flax import jax_utils +from flax.optim import Adam +from flax.training import common_utils +from flax.training.common_utils import get_metrics +from jax.nn import log_softmax +from modeling_flax_performer import FlaxPerformerForMaskedLM +from transformers import ( + MODEL_FOR_MASKED_LM_MAPPING, + AutoTokenizer, + BertConfig, + FlaxBertForMaskedLM, + HfArgumentParser, + PreTrainedTokenizerBase, + TensorType, + TrainingArguments, + is_tensorboard_available, + set_seed, +) + + +# Cache the result +has_tensorboard = is_tensorboard_available() +if has_tensorboard: + try: + from flax.metrics.tensorboard import SummaryWriter + except ImportError as ie: + has_tensorboard = False + print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}") + +else: + print( + "Unable to display metrics through TensorBoard because the package is not installed: " + 
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    # Checkpoint name or path used to initialise the weights.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The trailing space keeps the two adjacent literals from running together in `--help` output.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Selects the Performer (FAVOR+) model class instead of the standard one.
    performer: bool = field(
        default=False,
        metadata={"help": "Whether to use FAVOR+ attention"},
    )
    # Start from freshly initialised weights instead of the pretrained checkpoint.
    reinitialize: bool = field(
        default=False,
        metadata={"help": "Whether to use a blank model without pretraining"},
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
+ """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + train_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input train ref data file for whole word masking in Chinese."}, + ) + validation_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated. Default to the max input length of the model." + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
@dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`. When the key is missing, the mask is recomputed from the
        token ids with :obj:`tokenizer.get_special_tokens_mask`.
    """

    tokenizer: "PreTrainedTokenizerBase"
    mlm: bool = True
    mlm_probability: float = 0.15

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: Optional[int] = None
    ) -> Dict[str, np.ndarray]:
        """
        Pad `examples` into a NumPy batch and attach the `labels` entry.

        Args:
            examples: tokenized samples to collate.
            pad_to_multiple_of: forwarded to `tokenizer.pad`; `None` leaves the padded length unconstrained.

        Returns:
            The padded batch with `labels` added and `special_tokens_mask` removed.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].copy()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: token ids; modified in place and also returned.
            special_tokens_mask: truthy where a token must never be masked. When `None`, it is computed per
                sequence with `tokenizer.get_special_tokens_mask(..., already_has_special_tokens=True)`.

        Returns:
            Tuple `(inputs, labels)` where `labels` is -100 everywhere except at the masked positions.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # The batch carried no precomputed mask: derive it from the ids instead of crashing on `None`.
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
                for ids in labels.tolist()
            ]
        special_tokens_mask = np.asarray(special_tokens_mask).astype("bool")

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the masked positions that were not replaced by the mask token above)
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced

        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
def accuracy(logits, targets, weights=None):
    """Compute weighted accuracy for logits and targets.
    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]. When None, every position counts with weight 1.
    Returns:
      Tuple of (weighted number of correct predictions, batch normalizing factor).
    Raises:
      ValueError: if `logits` does not have exactly one more dimension than `targets`.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
        )

    correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)

    if weights is None:
        # Same unweighted normalizer as `cross_entropy`: the number of target positions.
        return correct.sum(), np.prod(targets.shape)

    correct = correct * weights
    return correct.sum(), weights.sum()
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> List[np.ndarray]:
    """
    Split sample indices into consecutive batches of exactly `batch_size` elements.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: 1D array of sample indices.
        batch_size: number of indices per batch; must be positive.

    Returns:
        A list of arrays, each of length `batch_size`. The list is empty when there are fewer than `batch_size`
        indices.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %s" % (batch_size,))

    nb_batches = len(samples_idx) // batch_size
    if nb_batches == 0:
        # `np.split` rejects zero sections; an undersized dataset simply yields no batch.
        return []

    # Keep only the indices that fill complete batches, then cut them into equal sections.
    return np.split(samples_idx[: nb_batches * batch_size], nb_batches)
+ model_args, data_args, training_args, wandb_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args, wandb_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + level="NOTSET", + datefmt="[%X]", + ) + + # Log on each process the small summary: + logger = logging.getLogger(__name__) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + datasets = load_dataset(extension, data_files=data_files) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+
+    # Base PRNG key for the run, plus one dropout key per local device.
+    rng = jax.random.PRNGKey(training_args.seed)
+    dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+    config = BertConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+    # --performer selects the Performer masked-LM class; without it the standard Flax BERT class is used.
+    lm_class = FlaxPerformerForMaskedLM if model_args.performer else FlaxBertForMaskedLM
+    if model_args.reinitialize:
+        # Blank model: only the architecture config comes from the checkpoint. Reuse the config loaded
+        # above so that cache_dir is honored and the config is not resolved a second time.
+        model = lm_class(config=config)
+    else:
+        model = lm_class.from_pretrained(
+            model_args.model_name_or_path,
+            dtype=jnp.float32,
+            input_shape=(training_args.train_batch_size, config.max_position_embeddings),
+            seed=training_args.seed,
+            dropout_rate=0.1,
+        )
+
+    # Prefer an explicitly requested tokenizer and fall back to the one of the model checkpoint.
+    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
+    if not tokenizer_source:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_source, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+    )
+
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    # Column names are read from the split this run will consume: "train" when training, else "validation".
+    reference_split = "train" if training_args.do_train else "validation"
+    column_names = datasets[reference_split].column_names
+    # Tokenize the "text" column when there is one, otherwise the first column.
+    if "text" in column_names:
+        text_column_name = "text"
+    else:
+        text_column_name = column_names[0]
+
+    if data_args.pad_to_max_length:
+        padding = "max_length"
+    else:
+        padding = False
+
+    def tokenize_function(examples):
+        """Tokenize a batch of text lines, leaving out empty and whitespace-only ones."""
+        kept_lines = []
+        for line in examples:
+            if len(line) == 0 or line.isspace():
+                continue
+            kept_lines.append(line)
+        return tokenizer(
+            kept_lines,
+            max_length=data_args.max_seq_length,
+            truncation=True,
+            padding=padding,
+            return_special_tokens_mask=True,
+        )
+
+    # Replace the original columns with the tokenizer output; only the text column is fed to the function.
+    tokenized_datasets = datasets.map(
+        tokenize_function,
+        batched=True,
+        input_columns=[text_column_name],
+        remove_columns=column_names,
+        num_proc=data_args.preprocessing_num_workers,
+        load_from_cache_file=not data_args.overwrite_cache,
+    )
+
+    # Enable tensorboard only on the master node
+    if has_tensorboard and jax.host_id() == 0:
+        tensorboard_dir = Path(training_args.output_dir) / "logs"
+        summary_writer = SummaryWriter(log_dir=tensorboard_dir.as_posix())
+
+    # Data collator
+    # This one will take care of randomly masking the tokens.
+    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+
+    # Setup optimizer
+    optimizer = Adam(
+        learning_rate=training_args.learning_rate,
+        weight_decay=training_args.weight_decay,
+        beta1=training_args.adam_beta1,
+        beta2=training_args.adam_beta2,
+    ).create(model.params)
+
+    # Create learning rate scheduler (warmup_steps is clamped to at least 1).
+    # NOTE(review): lr_scheduler_fn is not referenced again in this section; presumably training_step,
+    # defined outside it, reads it as a module-level name - confirm before renaming or removing.
+    lr_scheduler_fn = create_learning_rate_scheduler(
+        base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
+    )
+
+    # Create parallel version of the training and evaluation steps.
+    # Only the training step donates its first argument: the optimizer passed in is replaced by the one the
+    # step returns. The evaluation step receives optimizer.target, which is read again for every later batch
+    # and epoch, so its buffers are not donated.
+    p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
+    p_eval_step = jax.pmap(eval_step, "batch")
+
+    # Replicate the optimizer on each device
+    optimizer = jax_utils.replicate(optimizer)
+
+    # Store some constant
+    nb_epochs = int(training_args.num_train_epochs)
+    batch_size = int(training_args.train_batch_size)
+    eval_batch_size = int(training_args.eval_batch_size)
+
+    # Weights & Biases logging is enabled by passing a user name; wandb is only imported in that case.
+    use_wandb = wandb_args.wandb_user_name is not None
+    if use_wandb:
+        import wandb
+
+        wandb.init(project=wandb_args.wandb_project_name, entity=wandb_args.wandb_user_name)
+
+    # NOTE(review): the nesting of the loops below was reconstructed from a whitespace-collapsed source;
+    # confirm that the training loss is logged per step and the "Loss:" line is written once per epoch.
+    epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
+    for epoch in epochs:
+
+        # ======================== Training ================================
+        # Create sampling rng (eval_rng is split off but not used by the code in this section).
+        rng, training_rng, eval_rng = jax.random.split(rng, 3)
+
+        # Generate an epoch by shuffling sampling indices from the train dataset
+        nb_training_samples = len(tokenized_datasets["train"])
+        training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
+        training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
+
+        # Gather the indexes for creating the batch and do a training step
+        for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
+            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
+            model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+            # Model forward
+            model_inputs = common_utils.shard(model_inputs.data)
+            loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
+
+            if use_wandb:
+                wandb.log({"Training loss": np.array(loss).mean()})
+
+        # Report the loss returned by the last training step of the epoch.
+        epochs.write(f"Loss: {loss}")
+
+        # ======================== Evaluating ==============================
+        # Evaluation walks the validation set in order, without shuffling.
+        nb_eval_samples = len(tokenized_datasets["validation"])
+        eval_samples_idx = jnp.arange(nb_eval_samples)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+        eval_metrics = []
+        for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
+            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+            model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+            # Model forward
+            model_inputs = common_utils.shard(model_inputs.data)
+            metrics = p_eval_step(optimizer.target, model_inputs)
+            eval_metrics.append(metrics)
+
+        # Sum every metric over all of its entries, then divide each one by the summed "normalizer" entry.
+        eval_metrics_np = get_metrics(eval_metrics)
+        eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
+        eval_normalizer = eval_metrics_np.pop("normalizer")
+        eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
+
+        # Update progress bar
+        epochs.desc = (
+            f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
+        )
+
+        if use_wandb:
+            wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})
+
+        # Save metrics
+        if has_tensorboard and jax.host_id() == 0:
+            for name, value in eval_summary.items():
+                summary_writer.scalar(name, value, epoch)
diff --git a/examples/research_projects/performer/sanity_script.sh b/examples/research_projects/performer/sanity_script.sh
new file mode 100755
index 0000000000..b96cd7e643
--- /dev/null
+++ b/examples/research_projects/performer/sanity_script.sh
@@ -0,0 +1 @@
+TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.simple --model_name_or_path bert-base-cased --tokenizer_name bert-base-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
\ No newline at end of file