Flax Masked Language Modeling training example (#8728)

* Remove "Model" suffix from Flax models to look more 🤗

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Initial working (forward + backward) for Flax MLM training example.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Simplify code

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing comments, using module and moving to LM task.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore parameter name "module" wrongly renamed model.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore correct output ordering...

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Actually commit the example 😅

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add FlaxBertModelForMaskedLM after rebasing.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make it possible to initialize the training from scratch

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reuse flax linen example of cross entropy loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added specific data collator for flax

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove todo for data collator

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added evaluation step

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added ability to provide dtype to support bfloat16 on TPU

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable flax tensorboard output

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable jax.pmap support.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure batches are correctly sized to be dispatched with jax.pmap

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable bfloat16 with --fp16 cmdline args

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correctly export metrics to tensorboard

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added dropout and ability to use it.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Effectively enable & disable during training and evaluation steps.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Oops.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable specifying kernel initializer scale

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added warmup step to the learning rate scheduler.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix typo.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Print training loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix linter issue (flake8)

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix model matching

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix dummies

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix non default dtype on Flax models

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same create_position_ids_from_input_ids for FlaxRoberta

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta attention as Bert

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix copy

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Wording.

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
This commit is contained in:
Funtowicz Morgan
2020-12-09 17:13:56 +01:00
committed by GitHub
parent df2af6d8b8
commit 75627148ee
8 changed files with 1187 additions and 132 deletions

View File

@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
from typing import Callable, Dict, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
@@ -101,8 +102,8 @@ class FlaxRobertaLayerNorm(nn.Module):
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
@nn.compact
def __call__(self, x):
@@ -122,11 +123,13 @@ class FlaxRobertaLayerNorm(nn.Module):
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
return y
@@ -139,7 +142,9 @@ class FlaxRobertaEmbedding(nn.Module):
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
kernel_init_scale: float = 0.2
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs):
@@ -155,66 +160,108 @@ class FlaxRobertaEmbeddings(nn.Module):
hidden_size: int
type_vocab_size: int
max_length: int
kernel_init_scale: float = 0.2
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
    """Sum word, position and token-type embeddings, then normalize and apply dropout.

    Args:
        input_ids: token ids; cast to int32 and promoted to at least 2-D.
        token_type_ids: token-type ids; cast to int32 and promoted to at least 2-D.
        position_ids: position ids; cast to int32 and promoted to at least 2-D.
        attention_mask: accepted for signature parity; not used by this method.
        deterministic: forwarded to ``nn.Dropout``; ``True`` disables dropout.

    Returns:
        The layer-normalized (and possibly dropped-out) sum of the three embeddings.
    """
    # Embed
    w_emb = FlaxRobertaEmbedding(
        self.vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="word_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(input_ids.astype("i4")))
    p_emb = FlaxRobertaEmbedding(
        self.max_length,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="position_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(position_ids.astype("i4")))
    t_emb = FlaxRobertaEmbedding(
        self.type_vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="token_type_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(token_type_ids.astype("i4")))

    # Sum all embeddings (position embeddings are broadcast to the word-embedding shape)
    summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

    # Layer Norm
    layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
    embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
    return embeddings
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
    """Self-attention sub-layer followed by a residual connection and layer norm."""

    num_heads: int
    head_size: int
    dropout_rate: float = 0.0  # forwarded to nn.attention.SelfAttention
    kernel_init_scale: float = 0.2  # first argument of the normal kernel initializer
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
        """Apply self-attention to ``hidden_state`` and normalize the residual sum.

        Args:
            hidden_state: input to attend over; also added back as the residual.
            attention_mask: mask of shape ``(*batch_sizes, kv_length)``.
            deterministic: forwarded to ``SelfAttention`` together with ``dropout_rate``.

        Returns:
            ``FlaxRobertaLayerNorm`` applied to ``self_att + hidden_state``.
        """
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        self_att = nn.attention.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.head_size,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            bias_init=jax.nn.initializers.zeros,
            name="self",
            dtype=self.dtype,
        )(hidden_state, attention_mask)

        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
        return layer_norm
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
    """Dense projection to ``output_size`` features followed by GELU."""

    output_size: int
    kernel_init_scale: float = 0.2  # first argument of the normal kernel initializer
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_state):
        """Project ``hidden_state`` to ``output_size`` features and apply ``gelu``."""
        # TODO: Add ACT2FN reference to change activation function
        dense = nn.Dense(
            features=self.output_size,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(hidden_state)
        return gelu(dense)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
    """Dense projection, dropout, then residual connection and layer norm."""

    dropout_rate: float = 0.0  # rate passed to nn.Dropout
    kernel_init_scale: float = 0.2  # first argument of the normal kernel initializer
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
        """Project ``intermediate_output`` and merge it with ``attention_output``.

        Args:
            intermediate_output: input to the dense projection.
            attention_output: residual input; its last dimension sets the dense output size.
            deterministic: forwarded to ``nn.Dropout``; ``True`` disables dropout.

        Returns:
            ``FlaxRobertaLayerNorm`` applied to the projected state plus ``attention_output``.
        """
        hidden_state = nn.Dense(
            attention_output.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(intermediate_output)
        hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
        hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
        return hidden_state
@@ -222,14 +269,29 @@ class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
    """Run one layer: attention, intermediate projection, then output block.

    Args:
        hidden_state: input to the attention sub-layer.
        attention_mask: mask forwarded to the attention sub-layer.
        deterministic: forwarded to the attention and output sub-layers.

    Returns:
        The result of ``FlaxRobertaOutput`` applied to the intermediate and attention outputs.
    """
    attention = FlaxRobertaAttention(
        self.num_heads,
        self.head_size,
        kernel_init_scale=self.kernel_init_scale,
        dropout_rate=self.dropout_rate,
        name="attention",
        dtype=self.dtype,
    )(hidden_state, attention_mask, deterministic=deterministic)
    intermediate = FlaxRobertaIntermediate(
        self.intermediate_size,
        kernel_init_scale=self.kernel_init_scale,
        name="intermediate",
        dtype=self.dtype,
    )(attention)
    output = FlaxRobertaOutput(
        kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
    )(intermediate, attention, deterministic=deterministic)

    return output
@@ -244,9 +306,12 @@ class FlaxRobertaLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs, attention_mask):
def __call__(self, inputs, attention_mask, deterministic: bool = True):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
@@ -254,8 +319,16 @@ class FlaxRobertaLayerCollection(nn.Module):
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
layer = FlaxRobertaLayer(
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name=f"{i}",
dtype=self.dtype,
)
input_i = layer(input_i, attention_mask, deterministic=deterministic)
return input_i
@@ -265,22 +338,40 @@ class FlaxRobertaEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
    """Run ``hidden_state`` through the ``FlaxRobertaLayerCollection`` named "layer".

    Args:
        hidden_state: input to the layer collection.
        attention_mask: mask forwarded to the layer collection.
        deterministic: forwarded to the layer collection.

    Returns:
        The output of the layer collection.
    """
    layer = FlaxRobertaLayerCollection(
        self.num_layers,
        self.num_heads,
        self.head_size,
        self.intermediate_size,
        kernel_init_scale=self.kernel_init_scale,
        dropout_rate=self.dropout_rate,
        name="layer",
        dtype=self.dtype,
    )(hidden_state, attention_mask, deterministic=deterministic)
    return layer
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
    """Pool a sequence by projecting its first position through a dense layer and tanh."""

    kernel_init_scale: float = 0.2  # first argument of the normal kernel initializer
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_state):
        """Return ``tanh(Dense(hidden_state[:, 0]))`` with the feature size unchanged."""
        # Index 0 along axis 1 is selected as the pooled representation.
        cls_token = hidden_state[:, 0]
        out = nn.Dense(
            hidden_state.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
@@ -293,21 +384,38 @@ class FlaxRobertaModule(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
    """Embed the inputs, run the encoder stack and pool the result.

    Args:
        input_ids: token ids passed to the embeddings.
        attention_mask: mask passed to the embeddings and to the encoder.
        token_type_ids: token-type ids passed to the embeddings.
        position_ids: position ids passed to the embeddings.
        deterministic: forwarded to the embeddings and the encoder.

    Returns:
        A tuple ``(encoder, pooled)``: the encoder output and the pooler output.
    """
    # Embedding
    embeddings = FlaxRobertaEmbeddings(
        self.vocab_size,
        self.hidden_size,
        self.type_vocab_size,
        self.max_length,
        kernel_init_scale=self.kernel_init_scale,
        dropout_rate=self.dropout_rate,
        name="embeddings",
        dtype=self.dtype,
    )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)

    # N stacked encoding layers
    encoder = FlaxRobertaEncoder(
        self.num_encoder_layers,
        self.num_heads,
        self.head_size,
        self.intermediate_size,
        kernel_init_scale=self.kernel_init_scale,
        dropout_rate=self.dropout_rate,
        name="encoder",
        dtype=self.dtype,
    )(embeddings, attention_mask, deterministic=deterministic)

    pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
    return encoder, pooled
@@ -396,8 +504,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxRobertaModule(
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
module = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
@@ -406,31 +514,78 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
super().__init__(config, module, state, seed)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
    self,
    input_ids,
    token_type_ids=None,
    attention_mask=None,
    position_ids=None,
    params: dict = None,
    dropout_rng: PRNGKey = None,
    train: bool = False,
):
    """Apply the wrapped module to the given inputs.

    Args:
        input_ids: token ids.
        token_type_ids: optional; defaulted by ``_check_inputs`` when ``None``.
        attention_mask: optional; defaulted by ``_check_inputs`` when ``None``.
        position_ids: optional; defaulted by ``_check_inputs`` when ``None``.
        params: parameters to apply; ``self.params`` is used when this is falsy.
        dropout_rng: PRNG key supplied to the module under the ``"dropout"`` name.
        train: when ``True`` the module is called with ``deterministic=False``.

    Returns:
        Whatever ``self.module.apply`` returns for these inputs.
    """
    input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
        input_ids, attention_mask, token_type_ids, position_ids
    )

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    return self.module.apply(
        {"params": params or self.params},
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(token_type_ids, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        not train,  # positional `deterministic` argument of the module
        rngs=rngs,
    )
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
    """Initialize ``self.params`` by running the module's ``init`` on dummy inputs.

    Args:
        rng: PRNG key; split into one key for parameters and one for dropout.
        input_shape: shape of the all-zero int32 ``input_ids`` used for initialization.
    """
    dummy_ids = jnp.zeros(input_shape, dtype="i4")
    input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(dummy_ids, None, None, None)

    params_key, dropout_key = jax.random.split(rng)
    variables = self.module.init(
        {"params": params_key, "dropout": dropout_key},
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
    )
    # Only the "params" collection is kept.
    self.params = variables["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
    symbols keep the value padding_idx. This is modified from fairseq's `utils.make_positions`.

    Positions are counted along the last axis, so both a single sequence of shape (sequence_length,) and a batch
    of shape (batch_size, sequence_length) are accepted.

    Args:
        input_ids: jnp.ndarray
        padding_idx: int

    Returns: jnp.ndarray
    """
    # 1 where the token is not padding, 0 where it is.
    mask = (input_ids != padding_idx).astype("i4")
    # Running count of non-padding tokens; multiplying by the mask zeroes the padding slots
    # so they end up equal to padding_idx after the offset below.
    incremental_indices = jnp.cumsum(mask, axis=-1).astype("i4") * mask
    return incremental_indices + padding_idx