Adding performer fine-tuning research exampke (#9239)
* added run_mlm_performer.py research example * make styke * make styke * Added a README !
This commit is contained in:
25
examples/research_projects/performer/README.md
Normal file
25
examples/research_projects/performer/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Performer fine-tuning
|
||||
|
||||
Example authors: @TevenLeScao, @Patrickvonplaten
|
||||
|
||||
Paper authors: Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, Adrian Weller
|
||||
|
||||
## Requirements
|
||||
|
||||
`datasets`, `flax` and `jax`. `wandb` integration is built-in if you want to use it.
|
||||
|
||||
## Examples
|
||||
|
||||
`sanity_script.sh` will launch performer fine-tuning from the bert-base-cased checkpoint on the Simple Wikipedia dataset (a small, easy-language English Wikipedia) from `datasets`.
|
||||
`full_script.sh` will launch performer fine-tuning from the bert-large-cased checkpoint on the English Wikipedia dataset from `datasets`.
|
||||
|
||||
Here are a few key arguments:
|
||||
- Remove the `--performer` argument to use a standard Bert model.
|
||||
|
||||
- Add `--reinitialize` to start from a blank model rather than a Bert checkpoint.
|
||||
|
||||
- You may change the Bert size by passing a different [checkpoint](https://huggingface.co/transformers/pretrained_models.html) to the `--model_name_or_path` argument.
|
||||
|
||||
- Passing your user name to the `--wandb_user_name` argument will trigger weights and biases logging.
|
||||
|
||||
- You can choose a dataset with `--dataset_name` and `--dataset_config`. Our [viewer](https://huggingface.co/datasets/viewer/) will help you find what you need.
|
||||
1
examples/research_projects/performer/full_script.sh
Executable file
1
examples/research_projects/performer/full_script.sh
Executable file
@@ -0,0 +1 @@
|
||||
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.en --model_name_or_path bert-large-cased --tokenizer_name bert-large-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
|
||||
553
examples/research_projects/performer/modeling_flax_performer.py
Normal file
553
examples/research_projects/performer/modeling_flax_performer.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.random import PRNGKey
|
||||
from modeling_flax_performer_utils import make_fast_softmax_attention
|
||||
from transformers.file_utils import add_start_docstrings
|
||||
from transformers.modeling_flax_utils import ACT2FN
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead, FlaxBertPreTrainedModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
BERT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
BERT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 0 corresponds to a `sentence A` token,
|
||||
- 1 corresponds to a `sentence B` token.
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxPerformerLayerNorm(nn.Module):
|
||||
"""
|
||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||
"""
|
||||
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
bias: bool = True # If True, bias (beta) is added.
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||
scale_init: jnp.ndarray = nn.initializers.ones
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
||||
|
||||
Args:
|
||||
x: the inputs
|
||||
|
||||
Returns:
|
||||
Normalized inputs (the same shape as inputs).
|
||||
"""
|
||||
features = x.shape[-1]
|
||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||
y = (x - mean) * mul
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||
return y
|
||||
|
||||
|
||||
class FlaxPerformerEmbedding(nn.Module):
|
||||
"""
|
||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
||||
use 'weight'
|
||||
"""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
||||
return jnp.take(embedding, inputs, axis=0)
|
||||
|
||||
|
||||
class FlaxPerformerEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
# Embed
|
||||
w_emb = FlaxPerformerEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||
jnp.atleast_2d(input_ids.astype("i4"))
|
||||
)
|
||||
p_emb = FlaxPerformerEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||
jnp.atleast_2d(position_ids.astype("i4"))
|
||||
)
|
||||
t_emb = FlaxPerformerEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||
)
|
||||
|
||||
# Sum all embeddings
|
||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||
|
||||
# Layer Norm
|
||||
layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(summed_emb)
|
||||
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxPerformerAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
single_head_dim = self.head_size // self.num_heads
|
||||
fast_softmax_attention = make_fast_softmax_attention(qkv_dim=single_head_dim)
|
||||
self_att = nn.attention.SelfAttention(
|
||||
num_heads=self.num_heads, qkv_features=self.head_size, name="self", attention_fn=fast_softmax_attention
|
||||
)(hidden_state, attention_mask)
|
||||
|
||||
layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxPerformerIntermediate(nn.Module):
|
||||
output_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
return ACT2FN[self.hidden_act](dense)
|
||||
|
||||
|
||||
class FlaxPerformerOutput(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxPerformerLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class FlaxPerformerLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxPerformerAttention(self.num_heads, self.head_size, name="attention")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
intermediate = FlaxPerformerIntermediate(
|
||||
self.intermediate_size, name="intermediate", hidden_act=self.hidden_act
|
||||
)(attention)
|
||||
output = FlaxPerformerOutput(name="output")(intermediate, attention)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FlaxPerformerLayerCollection(nn.Module):
|
||||
"""
|
||||
Stores N BertLayer(s)
|
||||
"""
|
||||
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
input_i = inputs
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxPerformerLayer(
|
||||
self.num_heads, self.head_size, self.intermediate_size, hidden_act=self.hidden_act, name=f"{i}"
|
||||
)
|
||||
input_i = layer(input_i, attention_mask)
|
||||
return input_i
|
||||
|
||||
|
||||
class FlaxPerformerEncoder(nn.Module):
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
layer = FlaxPerformerLayerCollection(
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
name="layer",
|
||||
hidden_act=self.hidden_act,
|
||||
)(hidden_state, attention_mask)
|
||||
return layer
|
||||
|
||||
|
||||
class FlaxPerformerPooler(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
|
||||
|
||||
class FlaxPerformerModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
# Embedding
|
||||
embeddings = FlaxPerformerEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxPerformerEncoder(
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
name="encoder",
|
||||
)(embeddings, attention_mask)
|
||||
|
||||
if not self.add_pooling_layer:
|
||||
return encoder
|
||||
|
||||
pooled = FlaxPerformerPooler(name="pooler")(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
model_class = FlaxPerformerModule
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "bert"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight":
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxPerformerModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
max_length=config.max_position_embeddings,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
dropout_rate=config.hidden_dropout_prob,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
|
||||
def __call__(
|
||||
self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
|
||||
):
|
||||
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
rng=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxPerformerForMaskedLM(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxPerformerForMaskedLMModule(
|
||||
vocab_size=config.vocab_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
head_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
max_length=config.max_position_embeddings,
|
||||
hidden_act=config.hidden_act,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
train: bool = False,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxPerformerForMaskedLMModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
head_size: int
|
||||
num_heads: int
|
||||
num_encoder_layers: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
hidden_act: str
|
||||
dropout_rate: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
encoder = FlaxPerformerModule(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
max_length=self.max_length,
|
||||
num_encoder_layers=self.num_encoder_layers,
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
add_pooling_layer=False,
|
||||
name="bert",
|
||||
)(input_ids, attention_mask, token_type_ids, position_ids)
|
||||
|
||||
# Compute the prediction scores
|
||||
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
||||
logits = FlaxBertOnlyMLMHead(
|
||||
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
|
||||
)(encoder)
|
||||
|
||||
return (logits,)
|
||||
@@ -0,0 +1,660 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
IMPORTANT:
|
||||
|
||||
This code was copied from
|
||||
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
|
||||
6/11/2020. This is very new code, so it might be prone to change soon -> make sure to check the original code and
|
||||
update accordingly
|
||||
|
||||
Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism
|
||||
leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix.
|
||||
"""
|
||||
# pylint: disable=invalid-name, missing-function-docstring, line-too-long
|
||||
|
||||
import abc
|
||||
import functools
|
||||
from collections.abc import Iterable # pylint: disable=g-importing-member
|
||||
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax, random
|
||||
|
||||
|
||||
def nonnegative_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
|
||||
):
|
||||
"""
|
||||
Constructs nonnegative kernel features for fast softmax attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: random matrix used to compute features
|
||||
attention_dims_t: tuple of attention dimensions
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
is_query: predicate indicating whether input data corresponds to queries or
|
||||
keys
|
||||
normalize_data: predicate indicating whether data should be normalized,
|
||||
eps: numerical stabilizer
|
||||
|
||||
Returns:
|
||||
Random features for fast softmax attention.
|
||||
"""
|
||||
del attention_dims_t
|
||||
if normalize_data:
|
||||
# We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
|
||||
# w_norm = w * data_normalizer for w in {q,k}.
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
diag_data = jnp.square(data)
|
||||
diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
|
||||
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
|
||||
diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
|
||||
|
||||
if is_query:
|
||||
last_dims_t = (len(data_dash.shape) - 1,)
|
||||
data_dash = ratio * (
|
||||
jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps
|
||||
)
|
||||
else:
|
||||
data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps)
|
||||
|
||||
return data_dash
|
||||
|
||||
|
||||
def sincos_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
|
||||
):
|
||||
"""
|
||||
Constructs kernel sin-cos features for fast softmax attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: random matrix used to compute features
|
||||
attention_dims_t: tuple of attention dimensions
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
normalize_data: predicate indicating whether data should be normalized
|
||||
|
||||
Returns:
|
||||
Random features for fast softmax attention.
|
||||
"""
|
||||
if normalize_data:
|
||||
# We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
|
||||
# exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
data_dash_cos = ratio * jnp.cos(data_dash)
|
||||
data_dash_sin = ratio * jnp.sin(data_dash)
|
||||
data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)
|
||||
|
||||
# Constructing D_data and data^{'}
|
||||
diag_data = jnp.square(data)
|
||||
diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
|
||||
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
|
||||
diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
|
||||
# Additional renormalization for numerical stability
|
||||
data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
|
||||
diag_data -= data_renormalizer
|
||||
diag_data = jnp.exp(diag_data)
|
||||
data_prime = data_dash * diag_data
|
||||
return data_prime
|
||||
|
||||
|
||||
def generalized_kernel_feature_creator(
|
||||
data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
|
||||
):
|
||||
"""
|
||||
Constructs kernel features for fast generalized attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: matrix used to compute features
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
kernel_fn: kernel function used
|
||||
kernel_epsilon: additive positive term added to every feature for numerical
|
||||
stability
|
||||
normalize_data: predicate indicating whether data should be normalized
|
||||
|
||||
Returns:
|
||||
Random features for fast generalized attention.
|
||||
"""
|
||||
if normalize_data:
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
if projection_matrix is None:
|
||||
return kernel_fn(data_normalizer * data) + kernel_epsilon
|
||||
else:
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
data_prime = kernel_fn(data_dash) + kernel_epsilon
|
||||
return data_prime
|
||||
|
||||
|
||||
def make_fast_softmax_attention(
|
||||
qkv_dim,
|
||||
renormalize_attention=True,
|
||||
numerical_stabilizer=0.000001,
|
||||
nb_features=256,
|
||||
ortho_features=True,
|
||||
ortho_scaling=0.0,
|
||||
redraw_features=True,
|
||||
unidirectional=False,
|
||||
nonnegative_features=True,
|
||||
lax_scan_unroll=1,
|
||||
):
|
||||
"""Construct a fast softmax attention method."""
|
||||
logging.info(
|
||||
"Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
|
||||
nb_features,
|
||||
ortho_features,
|
||||
renormalize_attention,
|
||||
)
|
||||
if ortho_features:
|
||||
matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
|
||||
else:
|
||||
matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
|
||||
if nonnegative_features:
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
|
||||
):
|
||||
return nonnegative_softmax_kernel_feature_creator(
|
||||
data,
|
||||
projection_matrix,
|
||||
attention_dims_t,
|
||||
batch_dims_t,
|
||||
precision,
|
||||
is_query,
|
||||
normalize_data,
|
||||
numerical_stabilizer,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
|
||||
):
|
||||
del is_query
|
||||
return sincos_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
|
||||
)
|
||||
|
||||
attention_fn = FastAttentionviaLowRankDecomposition(
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention=renormalize_attention,
|
||||
numerical_stabilizer=numerical_stabilizer,
|
||||
redraw_features=redraw_features,
|
||||
unidirectional=unidirectional,
|
||||
lax_scan_unroll=lax_scan_unroll,
|
||||
).dot_product_attention
|
||||
return attention_fn
|
||||
|
||||
|
||||
def make_fast_generalized_attention(
|
||||
qkv_dim,
|
||||
renormalize_attention=True,
|
||||
numerical_stabilizer=0.0,
|
||||
nb_features=256,
|
||||
features_type="deterministic",
|
||||
kernel_fn=jax.nn.relu,
|
||||
kernel_epsilon=0.001,
|
||||
redraw_features=False,
|
||||
unidirectional=False,
|
||||
lax_scan_unroll=1,
|
||||
):
|
||||
"""Construct a fast generalized attention menthod."""
|
||||
logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)
|
||||
if features_type == "ortho":
|
||||
matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
|
||||
elif features_type == "iid":
|
||||
matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
|
||||
elif features_type == "deterministic":
|
||||
matrix_creator = None
|
||||
else:
|
||||
raise ValueError("Unknown feature value type")
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
|
||||
):
|
||||
del attention_dims_t
|
||||
del is_query
|
||||
return generalized_kernel_feature_creator(
|
||||
data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
|
||||
)
|
||||
|
||||
attention_fn = FastAttentionviaLowRankDecomposition(
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention=renormalize_attention,
|
||||
numerical_stabilizer=numerical_stabilizer,
|
||||
redraw_features=redraw_features,
|
||||
unidirectional=unidirectional,
|
||||
lax_scan_unroll=lax_scan_unroll,
|
||||
).dot_product_attention
|
||||
return attention_fn
|
||||
|
||||
|
||||
class RandomMatrix(object):
|
||||
r"""
|
||||
Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
|
||||
random arrays.
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_2d_array(self):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
|
||||
class GaussianUnstructuredRandomMatrix(RandomMatrix):
|
||||
def __init__(self, nb_rows, nb_columns, key):
|
||||
self.nb_rows = nb_rows
|
||||
self.nb_columns = nb_columns
|
||||
self.key = key
|
||||
|
||||
def get_2d_array(self):
|
||||
return random.normal(self.key, (self.nb_rows, self.nb_columns))
|
||||
|
||||
|
||||
class GaussianOrthogonalRandomMatrix(RandomMatrix):
|
||||
r"""
|
||||
Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian
|
||||
orthogonal arrays.
|
||||
"""
|
||||
|
||||
def __init__(self, nb_rows, nb_columns, key, scaling=0):
|
||||
self.nb_rows = nb_rows
|
||||
self.nb_columns = nb_columns
|
||||
self.key = key
|
||||
self.scaling = scaling
|
||||
|
||||
def get_2d_array(self):
|
||||
nb_full_blocks = int(self.nb_rows / self.nb_columns)
|
||||
block_list = []
|
||||
rng = self.key
|
||||
for _ in range(nb_full_blocks):
|
||||
rng, rng_input = jax.random.split(rng)
|
||||
unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
|
||||
q, _ = jnp.linalg.qr(unstructured_block)
|
||||
q = jnp.transpose(q)
|
||||
block_list.append(q)
|
||||
remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
|
||||
if remaining_rows > 0:
|
||||
rng, rng_input = jax.random.split(rng)
|
||||
unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
|
||||
q, _ = jnp.linalg.qr(unstructured_block)
|
||||
q = jnp.transpose(q)
|
||||
block_list.append(q[0:remaining_rows])
|
||||
final_matrix = jnp.vstack(block_list)
|
||||
|
||||
if self.scaling == 0:
|
||||
multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
|
||||
elif self.scaling == 1:
|
||||
multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
|
||||
else:
|
||||
raise ValueError("Scaling must be one of {0, 1}. Was %s" % self._scaling)
|
||||
|
||||
return jnp.matmul(jnp.diag(multiplier), final_matrix)
|
||||
|
||||
|
||||
class FastAttention(object):
|
||||
r"""
|
||||
Abstract class providing a method for fast attention. Class is responsible for providing a method
|
||||
<dot_product_attention> for fast approximate attention.
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dtype=jnp.float32,
|
||||
bias=None,
|
||||
axis=None,
|
||||
broadcast_dropout=True,
|
||||
dropout_rng=None,
|
||||
dropout_rate=0.0,
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
"""
|
||||
Computes dot-product attention given query, key, and value. This is the core function for applying fast
|
||||
approximate dot-product attention. It calculates the attention weights given query and key and combines the
|
||||
values using the attention weights. This function supports multi-dimensional inputs
|
||||
|
||||
Args:
|
||||
query: queries for calculating attention with shape of [batch_size, dim1,
|
||||
dim2, ..., dimN, num_heads, mem_channels].
|
||||
key: keys for calculating attention with shape of [batch_size, dim1, dim2,
|
||||
..., dimN, num_heads, mem_channels].
|
||||
value: values to be used in attention with shape of [batch_size, dim1,
|
||||
dim2,..., dimN, num_heads, value_channels].
|
||||
dtype: the dtype of the computation (default: float32)
|
||||
bias: bias for the attention weights. This can be used for incorporating
|
||||
autoregressive mask, padding mask, proximity bias.
|
||||
axis: axises over which the attention is applied.
|
||||
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
||||
dropout_rng: JAX PRNGKey: to be used for dropout.
|
||||
dropout_rate: dropout rate.
|
||||
deterministic: bool, deterministic or not (to apply dropout).
|
||||
precision: numerical precision of the computation see `jax.lax.Precision`
|
||||
for details
|
||||
|
||||
Returns:
|
||||
Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels].
|
||||
"""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
|
||||
def _numerator(z_slice_shape, precision, unroll=1):
|
||||
def fwd(qs, ks, vs):
|
||||
def body(p, qkv):
|
||||
(q, k, v) = qkv
|
||||
p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
|
||||
X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
|
||||
return p, X_slice
|
||||
|
||||
init_value = jnp.zeros(z_slice_shape)
|
||||
p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
|
||||
return W, (p, qs, ks, vs)
|
||||
|
||||
def bwd(pqkv, W_ct):
|
||||
def body(carry, qkv_xct):
|
||||
p, p_ct = carry
|
||||
q, k, v, x_ct = qkv_xct
|
||||
q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
|
||||
p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
|
||||
k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
|
||||
v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
|
||||
p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
|
||||
return (p, p_ct), (q_ct, k_ct, v_ct)
|
||||
|
||||
p, qs, ks, vs = pqkv
|
||||
_, (qs_ct, ks_ct, vs_ct) = lax.scan(
|
||||
body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
|
||||
)
|
||||
return qs_ct, ks_ct, vs_ct
|
||||
|
||||
@jax.custom_vjp
|
||||
def _numerator_impl(qs, ks, vs):
|
||||
W, _ = fwd(qs, ks, vs)
|
||||
return W
|
||||
|
||||
_numerator_impl.defvjp(fwd, bwd)
|
||||
|
||||
return _numerator_impl
|
||||
|
||||
|
||||
def _denominator(t_slice_shape, precision, unroll=1):
|
||||
def fwd(qs, ks):
|
||||
def body(p, qk):
|
||||
q, k = qk
|
||||
p += k
|
||||
x = jnp.einsum("...m,...m->...", q, p, precision=precision)
|
||||
return p, x
|
||||
|
||||
p = jnp.zeros(t_slice_shape)
|
||||
p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
|
||||
return R, (qs, ks, p)
|
||||
|
||||
def bwd(qkp, R_ct):
|
||||
def body(carry, qkx):
|
||||
p, p_ct = carry
|
||||
q, k, x_ct = qkx
|
||||
q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision)
|
||||
p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision)
|
||||
k_ct = p_ct
|
||||
p -= k
|
||||
return (p, p_ct), (q_ct, k_ct)
|
||||
|
||||
qs, ks, p = qkp
|
||||
_, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll)
|
||||
return (qs_ct, ks_ct)
|
||||
|
||||
@jax.custom_vjp
|
||||
def _denominator_impl(qs, ks):
|
||||
R, _ = fwd(qs, ks)
|
||||
return R
|
||||
|
||||
_denominator_impl.defvjp(fwd, bwd)
|
||||
|
||||
return _denominator_impl
|
||||
|
||||
|
||||
class FastAttentionviaLowRankDecomposition(FastAttention):
|
||||
r"""
|
||||
Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method
|
||||
<dot_product_attention> for fast dot-product attention with the use of low rank decomposition (e.g. with random
|
||||
feature maps).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention,
|
||||
numerical_stabilizer,
|
||||
redraw_features,
|
||||
unidirectional,
|
||||
lax_scan_unroll=1,
|
||||
): # For optimal GPU performance, set to 16.
|
||||
rng = random.PRNGKey(0)
|
||||
self.matrix_creator = matrix_creator
|
||||
self.projection_matrix = self.draw_weights(rng)
|
||||
self.kernel_feature_creator = kernel_feature_creator
|
||||
self.renormalize_attention = renormalize_attention
|
||||
self.numerical_stabilizer = numerical_stabilizer
|
||||
self.redraw_features = redraw_features
|
||||
self.unidirectional = unidirectional
|
||||
self.lax_scan_unroll = lax_scan_unroll
|
||||
|
||||
def draw_weights(self, key):
|
||||
if self.matrix_creator is None:
|
||||
return None
|
||||
matrixrng, _ = random.split(key)
|
||||
projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
|
||||
return projection_matrix
|
||||
|
||||
def dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dtype=jnp.float32,
|
||||
bias=None,
|
||||
axis=None,
|
||||
broadcast_dropout=True,
|
||||
dropout_rng=None,
|
||||
dropout_rate=0.0,
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
|
||||
assert key.shape[:-1] == value.shape[:-1]
|
||||
assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
|
||||
if axis is None:
|
||||
axis = tuple(range(1, key.ndim - 2))
|
||||
if not isinstance(axis, Iterable):
|
||||
axis = (axis,)
|
||||
assert key.ndim == query.ndim
|
||||
assert key.ndim == value.ndim
|
||||
for ax in axis:
|
||||
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
|
||||
raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
|
||||
n = key.ndim
|
||||
|
||||
# Constructing projection tensor.
|
||||
if self.redraw_features:
|
||||
# TODO(kchoro): Get rid of the constant below.
|
||||
query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
|
||||
rng = random.PRNGKey(query_seed)
|
||||
self.projection_matrix = self.draw_weights(rng)
|
||||
|
||||
# batch_dims is <bs, <non-attention dims>, num_heads>
|
||||
batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
|
||||
# q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
qk_perm = batch_dims + axis + (n - 1,)
|
||||
k_extra_perm = axis + batch_dims + (n - 1,)
|
||||
key_extra = key.transpose(k_extra_perm)
|
||||
key = key.transpose(qk_perm)
|
||||
query = query.transpose(qk_perm)
|
||||
# v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
v_perm = batch_dims + axis + (n - 1,)
|
||||
value = value.transpose(v_perm)
|
||||
batch_dims_t = tuple(range(len(batch_dims)))
|
||||
attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
|
||||
|
||||
# Constructing tensors Q^{'} and K^{'}.
|
||||
query_prime = self.kernel_feature_creator(
|
||||
query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
|
||||
)
|
||||
key_prime = self.kernel_feature_creator(
|
||||
key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
|
||||
)
|
||||
|
||||
if self.unidirectional:
|
||||
index = attention_dims_t[0]
|
||||
z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)
|
||||
|
||||
numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
|
||||
W = numerator_fn(
|
||||
jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
|
||||
)
|
||||
|
||||
# Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
|
||||
W = jnp.moveaxis(W, 0, index)
|
||||
|
||||
if not self.renormalize_attention:
|
||||
# Unidirectional, not-normalized attention.
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = W.transpose(perm_inv)
|
||||
return result
|
||||
else:
|
||||
# Unidirectional, normalized attention.
|
||||
thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
|
||||
|
||||
index = attention_dims_t[0]
|
||||
t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
|
||||
denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
|
||||
R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))
|
||||
|
||||
R = jnp.moveaxis(R, 0, index)
|
||||
else:
|
||||
contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
|
||||
contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
|
||||
# Constructing Z = (K^{'})^{T}V
|
||||
# Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
|
||||
Z = lax.dot_general(
|
||||
key_prime,
|
||||
value,
|
||||
((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
# Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
|
||||
# q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
|
||||
# Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
|
||||
# W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
|
||||
W = lax.dot_general(
|
||||
query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
|
||||
)
|
||||
if not self.renormalize_attention:
|
||||
# Bidirectional, not-normalized attention.
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = W.transpose(perm_inv)
|
||||
return result
|
||||
else:
|
||||
# Bidirectional, normalized attention.
|
||||
thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
|
||||
contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
|
||||
contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
|
||||
# Construct T = (K^{'})^{T} 1_L
|
||||
# k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
T = lax.dot_general(
|
||||
key_prime,
|
||||
thick_all_ones,
|
||||
((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
# Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
|
||||
# q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
|
||||
# T (bs, <non-attention dims>, num_heads, channels_m)
|
||||
R = lax.dot_general(
|
||||
query_prime,
|
||||
T,
|
||||
(((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
|
||||
R = jnp.reciprocal(R)
|
||||
R = jnp.expand_dims(R, len(R.shape))
|
||||
# W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
|
||||
# R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
|
||||
result = W * R
|
||||
# back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = result.transpose(perm_inv)
|
||||
return result
|
||||
|
||||
|
||||
def _invert_perm(perm):
|
||||
perm_inv = [0] * len(perm)
|
||||
for i, j in enumerate(perm):
|
||||
perm_inv[j] = i
|
||||
return tuple(perm_inv)
|
||||
685
examples/research_projects/performer/run_mlm_performer.py
Normal file
685
examples/research_projects/performer/run_mlm_performer.py
Normal file
@@ -0,0 +1,685 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
||||
text file or a dataset.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=masked-lm
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import jax_utils
|
||||
from flax.optim import Adam
|
||||
from flax.training import common_utils
|
||||
from flax.training.common_utils import get_metrics
|
||||
from jax.nn import log_softmax
|
||||
from modeling_flax_performer import FlaxPerformerForMaskedLM
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
FlaxBertForMaskedLM,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
TensorType,
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
# Cache the result
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard:
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unable to display metrics through TensorBoard because the package is not installed: "
|
||||
"Please run pip install tensorboard to enable."
|
||||
)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandbArguments:
|
||||
"""
|
||||
Arguments for logging
|
||||
"""
|
||||
|
||||
wandb_user_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The WandB user name for potential logging. If left None, no logging"},
|
||||
)
|
||||
wandb_project_name: Optional[str] = field(
|
||||
default="performer-experiments",
|
||||
metadata={"help": "The WandB project name for potential logging"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||
"""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
},
|
||||
)
|
||||
performer: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use FAVOR+ attention"},
|
||||
)
|
||||
reinitialize: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use a blank model without pretraining"},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
train_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
validation_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
validation_split_percentage: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
||||
},
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated. Default to the max input length of the model."
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
mlm_probability: float = field(
|
||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
# Adapted from transformers/data/data_collator.py
|
||||
# Letting here for now, let's discuss where it should live
|
||||
@dataclass
|
||||
class FlaxDataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
||||
are not all of the same length.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
|
||||
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
|
||||
non-masked tokens and the value to predict for the masked token.
|
||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
||||
|
||||
.. note::
|
||||
|
||||
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
||||
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
||||
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
||||
argument :obj:`return_special_tokens_mask=True`.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mlm and self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||
)
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
||||
|
||||
# If special token mask has been preprocessed, pop it from the dict.
|
||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||
if self.mlm:
|
||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||
)
|
||||
else:
|
||||
labels = batch["input_ids"].copy()
|
||||
if self.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
def mask_tokens(
|
||||
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""
|
||||
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||
"""
|
||||
labels = inputs.copy()
|
||||
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
||||
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
||||
special_tokens_mask = special_tokens_mask.astype("bool")
|
||||
|
||||
probability_matrix[special_tokens_mask] = 0.0
|
||||
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||
|
||||
# 10% of the time, we replace masked input tokens with random word
|
||||
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
||||
indices_random &= masked_indices & ~indices_replaced
|
||||
|
||||
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
||||
inputs[indices_random] = random_words[indices_random]
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def create_learning_rate_scheduler(
|
||||
factors="constant * linear_warmup * rsqrt_decay",
|
||||
base_learning_rate=0.5,
|
||||
warmup_steps=1000,
|
||||
decay_factor=0.5,
|
||||
steps_per_decay=20000,
|
||||
steps_per_cycle=100000,
|
||||
):
|
||||
"""Creates learning rate schedule.
|
||||
Interprets factors in the factors string which can consist of:
|
||||
* constant: interpreted as the constant value,
|
||||
* linear_warmup: interpreted as linear warmup until warmup_steps,
|
||||
* rsqrt_decay: divide by square root of max(step, warmup_steps)
|
||||
* rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
|
||||
* decay_every: Every k steps decay the learning rate by decay_factor.
|
||||
* cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
|
||||
Args:
|
||||
factors: string, factors separated by "*" that defines the schedule.
|
||||
base_learning_rate: float, the starting constant for the lr schedule.
|
||||
warmup_steps: int, how many steps to warm up for in the warmup schedule.
|
||||
decay_factor: float, the amount to decay the learning rate by.
|
||||
steps_per_decay: int, how often to decay the learning rate.
|
||||
steps_per_cycle: int, steps per cycle when using cosine decay.
|
||||
Returns:
|
||||
a function learning_rate(step): float -> {"learning_rate": float}, the
|
||||
step-dependent lr.
|
||||
"""
|
||||
factors = [n.strip() for n in factors.split("*")]
|
||||
|
||||
def step_fn(step):
|
||||
"""Step to learning rate function."""
|
||||
ret = 1.0
|
||||
for name in factors:
|
||||
if name == "constant":
|
||||
ret *= base_learning_rate
|
||||
elif name == "linear_warmup":
|
||||
ret *= jnp.minimum(1.0, step / warmup_steps)
|
||||
elif name == "rsqrt_decay":
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "rsqrt_normalized_decay":
|
||||
ret *= jnp.sqrt(warmup_steps)
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "decay_every":
|
||||
ret *= decay_factor ** (step // steps_per_decay)
|
||||
elif name == "cosine_decay":
|
||||
progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
|
||||
ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
|
||||
else:
|
||||
raise ValueError("Unknown factor %s." % name)
|
||||
return jnp.asarray(ret, dtype=jnp.float32)
|
||||
|
||||
return step_fn
|
||||
|
||||
|
||||
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
|
||||
"""Compute summary metrics."""
|
||||
loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
|
||||
acc, _ = accuracy(logits, labels, weights)
|
||||
metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer}
|
||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
|
||||
def accuracy(logits, targets, weights=None):
|
||||
"""Compute weighted accuracy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(
|
||||
"Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
|
||||
)
|
||||
|
||||
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
|
||||
loss *= weights
|
||||
|
||||
return loss.sum(), weights.sum()
|
||||
|
||||
|
||||
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
|
||||
"""Compute cross entropy and entropy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
label_smoothing: label smoothing constant, used to determine the on and off values.
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(
|
||||
"Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
|
||||
)
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
confidence = 1.0 - label_smoothing
|
||||
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||
normalizing_constant = -(
|
||||
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||
)
|
||||
soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||
|
||||
loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
|
||||
loss = loss - normalizing_constant
|
||||
|
||||
if weights is not None:
|
||||
loss = loss * weights
|
||||
normalizing_factor = weights.sum()
|
||||
else:
|
||||
normalizing_factor = np.prod(targets.shape)
|
||||
|
||||
return loss.sum(), normalizing_factor
|
||||
|
||||
|
||||
def training_step(optimizer, batch, dropout_rng):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
|
||||
def loss_fn(params):
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
|
||||
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
||||
return loss / weight_sum
|
||||
|
||||
step = optimizer.state.step
|
||||
lr = lr_scheduler_fn(step)
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(optimizer.target)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
|
||||
|
||||
return loss, optimizer, new_dropout_rng
|
||||
|
||||
|
||||
def eval_step(params, batch):
|
||||
"""
|
||||
Calculate evaluation metrics on a batch.
|
||||
"""
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
return compute_metrics(logits, targets, token_mask)
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
||||
nb_samples = len(samples_idx)
|
||||
samples_to_remove = nb_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = nb_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, WandbArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args, wandb_args = parser.parse_json_file(
|
||||
json_file=os.path.abspath(sys.argv[1])
|
||||
)
|
||||
else:
|
||||
model_args, data_args, training_args, wandb_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
level="NOTSET",
|
||||
datefmt="[%X]",
|
||||
)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
|
||||
if "validation" not in datasets.keys():
|
||||
datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
)
|
||||
datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
if extension == "txt":
|
||||
extension = "text"
|
||||
datasets = load_dataset(extension, data_files=data_files)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
config = BertConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
lm_class = FlaxPerformerForMaskedLM if model_args.performer else FlaxBertForMaskedLM
|
||||
if model_args.reinitialize:
|
||||
model = lm_class(config=BertConfig.from_pretrained(model_args.model_name_or_path))
|
||||
else:
|
||||
model = lm_class.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
dtype=jnp.float32,
|
||||
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||
seed=training_args.seed,
|
||||
dropout_rate=0.1,
|
||||
)
|
||||
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
else:
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
weight_decay=training_args.weight_decay,
|
||||
beta1=training_args.adam_beta1,
|
||||
beta2=training_args.adam_beta2,
|
||||
).create(model.params)
|
||||
|
||||
# Create learning rate scheduler
|
||||
lr_scheduler_fn = create_learning_rate_scheduler(
|
||||
base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
|
||||
)
|
||||
|
||||
# Create parallel version of the training and evaluation steps
|
||||
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
|
||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the optimizer on each device
|
||||
optimizer = jax_utils.replicate(optimizer)
|
||||
|
||||
# Store some constant
|
||||
nb_epochs = int(training_args.num_train_epochs)
|
||||
batch_size = int(training_args.train_batch_size)
|
||||
eval_batch_size = int(training_args.eval_batch_size)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
import wandb
|
||||
|
||||
wandb.init(project=wandb_args.wandb_project_name, entity=wandb_args.wandb_user_name)
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
# ======================== Training ================================
|
||||
# Create sampling rng
|
||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
||||
|
||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||
nb_training_samples = len(tokenized_datasets["train"])
|
||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
|
||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
wandb.log({"Training loss": np.array(loss).mean()})
|
||||
|
||||
epochs.write(f"Loss: {loss}")
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
metrics = p_eval_step(optimizer.target, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_metrics_np = get_metrics(eval_metrics)
|
||||
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
|
||||
eval_normalizer = eval_metrics_np.pop("normalizer")
|
||||
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
||||
|
||||
# Update progress bar
|
||||
epochs.desc = (
|
||||
f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
|
||||
)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})
|
||||
|
||||
# Save metrics
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
for name, value in eval_summary.items():
|
||||
summary_writer.scalar(name, value, epoch)
|
||||
1
examples/research_projects/performer/sanity_script.sh
Executable file
1
examples/research_projects/performer/sanity_script.sh
Executable file
@@ -0,0 +1 @@
|
||||
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.simple --model_name_or_path bert-base-cased --tokenizer_name bert-base-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
|
||||
Reference in New Issue
Block a user