Flax Masked Language Modeling training example (#8728)

* Remove "Model" suffix from Flax models to look more 🤗

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Initial working (forward + backward) for Flax MLM training example.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Simplify code

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing comments, using module and moving to LM task.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore parameter name "module" wrongly renamed model.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore correct output ordering...

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Actually commit the example 😅

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add FlaxBertModelForMaskedLM after rebasing.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make it possible to initialize the training from scratch

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reuse flax linen example of cross entropy loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added specific data collator for flax

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove todo for data collator

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added evaluation step

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added ability to provide dtype to support bfloat16 on TPU

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable flax tensorboard output

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable jax.pmap support.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure batches are correctly sized to be dispatched with jax.pmap

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable bfloat16 with --fp16 cmdline args

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correctly export metrics to tensorboard

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added dropout and ability to use it.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Effectively enable & disable during training and evaluation steps.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Oops.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable specifying kernel initializer scale

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added warmup step to the learning rate scheduler.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix typo.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Print training loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix linter issue (flake8)

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix model matching

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix dummies

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix non default dtype on Flax models

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same create_position_ids_from_input_ids for FlaxRoberta

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta attention as Bert

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix copy

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Wording.

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
This commit is contained in:
Funtowicz Morgan
2020-12-09 17:13:56 +01:00
committed by GitHub
parent df2af6d8b8
commit 75627148ee
8 changed files with 1187 additions and 132 deletions

View File

@@ -59,4 +59,4 @@ if is_tf_available():
)
if is_flax_available():
from .modeling_flax_bert import FlaxBertModel
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel

View File

@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
from typing import Callable, Dict, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
@@ -101,8 +102,8 @@ class FlaxBertLayerNorm(nn.Module):
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
@nn.compact
def __call__(self, x):
@@ -122,11 +123,13 @@ class FlaxBertLayerNorm(nn.Module):
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
return y
@@ -138,7 +141,9 @@ class FlaxBertEmbedding(nn.Module):
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
kernel_init_scale: float = 0.2
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs):
@@ -153,63 +158,105 @@ class FlaxBertEmbeddings(nn.Module):
hidden_size: int
type_vocab_size: int
max_length: int
kernel_init_scale: float = 0.2
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
jnp.atleast_2d(input_ids.astype("i4"))
)
p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
jnp.atleast_2d(position_ids.astype("i4"))
)
t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
jnp.atleast_2d(token_type_ids.astype("i4"))
)
w_emb = FlaxBertEmbedding(
self.vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="word_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(input_ids.astype("i4")))
p_emb = FlaxBertEmbedding(
self.max_length,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="position_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(position_ids.astype("i4")))
t_emb = FlaxBertEmbedding(
self.type_vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
name="token_type_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(token_type_ids.astype("i4")))
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
# Layer Norm
layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
return layer_norm
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
return embeddings
class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
self_att = nn.attention.SelfAttention(
num_heads=self.num_heads,
qkv_features=self.head_size,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_state, attention_mask)
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
return layer_norm
class FlaxBertIntermediate(nn.Module):
output_size: int
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
dense = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_state)
return gelu(dense)
class FlaxBertOutput(nn.Module):
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_state = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
return hidden_state
@@ -217,12 +264,26 @@ class FlaxBertLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
output = FlaxBertOutput(name="output")(intermediate, attention)
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
attention = FlaxBertAttention(
self.num_heads,
self.head_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
intermediate = FlaxBertIntermediate(
self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype
)(attention)
output = FlaxBertOutput(
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
)(intermediate, attention, deterministic=deterministic)
return output
@@ -236,9 +297,12 @@ class FlaxBertLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs, attention_mask):
def __call__(self, inputs, attention_mask, deterministic: bool = True):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
@@ -246,8 +310,16 @@ class FlaxBertLayerCollection(nn.Module):
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
layer = FlaxBertLayer(
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name=f"{i}",
dtype=self.dtype,
)
input_i = layer(input_i, attention_mask, deterministic=deterministic)
return input_i
@@ -256,21 +328,39 @@ class FlaxBertEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask):
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
layer = FlaxBertLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
)(hidden_state, attention_mask)
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
return layer
class FlaxBertPooler(nn.Module):
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
return jax.lax.tanh(out)
out = nn.Dense(
hidden_state.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(cls_token)
return nn.tanh(out)
class FlaxBertModule(nn.Module):
@@ -282,24 +372,104 @@ class FlaxBertModule(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
)(input_ids, token_type_ids, position_ids, attention_mask)
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
)(embeddings, attention_mask)
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
pooled = FlaxBertPooler(name="pooler")(encoder)
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
class FlaxBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> layer-norm transform applied before the LM decoder.

    The dense layer keeps the feature size of its input (the last axis of
    ``hidden_states``).
    """

    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        """Return the transformed hidden states (same trailing dimension as the input)."""
        hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
        # Use the same activation as FlaxBertIntermediate (gelu) instead of the
        # previous `nn.elu` placeholder, so the head matches BERT's default
        # `hidden_act`.
        # TODO: ACT2FN[config.hidden_act]
        hidden_states = gelu(hidden_states)
        return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
class FlaxBertLMPredictionHead(nn.Module):
    """Language-modeling head: head transform followed by a projection to the vocabulary."""

    vocab_size: int
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        """Return one score per vocabulary entry for every position in ``hidden_states``."""
        # TODO: The output weights are the same as the input embeddings, but there is
        #       an output-only bias for each token.
        #       Need a link between the two variables so that the bias is correctly
        #       resized with `resize_token_embeddings`
        transformed = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states)
        scores = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(transformed)
        return scores
class FlaxBertOnlyMLMHead(nn.Module):
    """BERT encoder (FlaxBertModule) followed by the masked-LM prediction head.

    Calling the module returns a ``(logits, pooled)`` tuple: the prediction
    scores computed from the encoder output, and the pooled output of the
    underlying FlaxBertModule.
    """

    vocab_size: int
    hidden_size: int
    intermediate_size: int
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        """Run the encoder and compute prediction scores.

        ``deterministic=True`` disables dropout, both inside the encoder and on
        the encoder output fed to the prediction head.
        """
        # Model
        encoder, pooled = FlaxBertModule(
            vocab_size=self.vocab_size,
            type_vocab_size=self.type_vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            # Honor the declared `head_size` field; it was previously ignored in
            # favor of `hidden_size`.
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_encoder_layers=self.num_encoder_layers,
            max_length=self.max_length,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)

        # Compute the prediction scores
        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
        logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder)

        return logits, pooled
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
@@ -385,8 +555,8 @@ class FlaxBertModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxBertModule(
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
module = FlaxBertModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
@@ -395,16 +565,43 @@ class FlaxBertModel(FlaxPreTrainedModel):
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
super().__init__(config, module, state, seed)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
@@ -414,10 +611,62 @@ class FlaxBertModel(FlaxPreTrainedModel):
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return self.model.apply(
{"params": self.params},
return input_ids, attention_mask, token_type_ids, position_ids
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
    """Initialize ``self.params`` by tracing the module on all-zero dummy inputs.

    ``rng`` is split in two: one key for parameter initialization and one for
    dropout. ``input_shape`` is the shape of the dummy ``input_ids`` array; the
    remaining inputs are filled in by ``_check_inputs``.
    """
    dummy_input_ids = jnp.zeros(input_shape, dtype="i4")
    # _check_inputs returns (input_ids, attention_mask, token_type_ids, position_ids)
    model_inputs = self._check_inputs(dummy_input_ids, None, None, None)

    params_key, dropout_key = jax.random.split(rng)
    variables = self.module.init({"params": params_key, "dropout": dropout_key}, *model_inputs)
    self.params = variables["params"]
class FlaxBertForMaskedLM(FlaxBertModel):
    """BERT model with a masked-language-modeling head (FlaxBertOnlyMLMHead)."""

    def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
        """Build the MLM module from ``config``.

        Args:
            config: model configuration the module sizes are read from.
            state: forwarded to the parent constructor.
            seed: forwarded to the parent constructor.
            dtype: dtype of the computation, forwarded to the parent and to the MLM module.
            **kwargs: extra fields forwarded to FlaxBertOnlyMLMHead (e.g. ``dropout_rate``).
        """
        super().__init__(config, state, seed, dtype)

        # The parent constructor builds a plain FlaxBertModule; assign the MLM
        # variant to `_module` instead (presumably what the `module` property
        # exposes -- confirm in FlaxPreTrainedModel).
        self._module = FlaxBertOnlyMLMHead(
            vocab_size=config.vocab_size,
            type_vocab_size=config.type_vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            # NOTE(review): this sizes the position embeddings from
            # `config.max_length`; `config.max_position_embeddings` looks like
            # the intended field -- confirm against FlaxBertModel.
            max_length=config.max_length,
            # Forward the requested computation dtype; it was previously dropped,
            # so the MLM module always ran with its float32 default.
            dtype=dtype,
            **kwargs,
        )

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        """Apply the MLM module and return ``(logits, pooled)``.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: model inputs;
                missing optional inputs are filled in by ``_check_inputs``.
            params: parameters to use instead of ``self.params`` when given.
            dropout_rng: PRNG key for dropout; only registered when not None.
            train: when True the module is applied with ``deterministic=False``.
        """
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # FlaxBertOnlyMLMHead returns (logits, pooled); unpack in that order so
        # the tuple handed back to the caller is not swapped.
        logits, pooled = self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,  # positional `deterministic` argument of the module
            rngs=rngs,
        )

        return logits, pooled