[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)

* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply Sylvain's tips

* update init

* make 0.3.0 compatible

* revert Teven's changes

* revert Teven's changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Patrick von Platen
2020-12-16 13:03:32 +01:00
committed by GitHub
parent 51adb97cd6
commit 640e6fe190
14 changed files with 700 additions and 359 deletions

View File

@@ -20,10 +20,11 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...utils import logging
from .configuration_bert import BertConfig
@@ -205,7 +206,7 @@ class FlaxBertAttention(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
@@ -219,27 +220,28 @@ class FlaxBertAttention(nn.Module):
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_state, attention_mask)
)(hidden_states, attention_mask)
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
return layer_norm
class FlaxBertIntermediate(nn.Module):
output_size: int
hidden_act: str = "gelu"
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(
def __call__(self, hidden_states):
hidden_states = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_state)
return gelu(dense)
)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
return hidden_states
class FlaxBertOutput(nn.Module):
@@ -249,27 +251,28 @@ class FlaxBertOutput(nn.Module):
@nn.compact
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_state = nn.Dense(
hidden_states = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
return hidden_state
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
return hidden_states
class FlaxBertLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention = FlaxBertAttention(
self.num_heads,
self.head_size,
@@ -277,9 +280,13 @@ class FlaxBertLayer(nn.Module):
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
intermediate = FlaxBertIntermediate(
self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
hidden_act=self.hidden_act,
name="intermediate",
dtype=self.dtype,
)(attention)
output = FlaxBertOutput(
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
@@ -297,6 +304,7 @@ class FlaxBertLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@@ -316,6 +324,7 @@ class FlaxBertLayerCollection(nn.Module):
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name=f"{i}",
dtype=self.dtype,
)
@@ -328,22 +337,24 @@ class FlaxBertEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
layer = FlaxBertLayerCollection(
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
hidden_act=self.hidden_act,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
return layer
@@ -352,10 +363,10 @@ class FlaxBertPooler(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
def __call__(self, hidden_states):
cls_token = hidden_states[:, 0]
out = nn.Dense(
hidden_state.shape[-1],
hidden_states.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
@@ -363,62 +374,20 @@ class FlaxBertPooler(nn.Module):
return nn.tanh(out)
class FlaxBertModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
class FlaxBertPredictionHeadTransform(nn.Module):
hidden_act: str = "gelu"
dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, hidden_states):
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
hidden_states = nn.elu(hidden_states) # TODO: ACT2FN[config.hidden_act]
return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
class FlaxBertLMPredictionHead(nn.Module):
vocab_size: int
hidden_act: str = "gelu"
dtype: jnp.dtype = jnp.float32
@nn.compact
@@ -428,64 +397,57 @@ class FlaxBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly
# resized with `resize_token_embeddings`
hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states)
hidden_states = FlaxBertPredictionHeadTransform(
name="transform", hidden_act=self.hidden_act, dtype=self.dtype
)(hidden_states)
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
return hidden_states
class FlaxBertOnlyMLMHead(nn.Module):
vocab_size: int
hidden_size: int
intermediate_size: int
head_size: int
num_heads: int
num_encoder_layers: int
type_vocab_size: int
max_length: int
dropout_rate: float = 0.0
hidden_act: str = "gelu"
dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
encoder, pooled = FlaxBertModule(
vocab_size=self.vocab_size,
type_vocab_size=self.type_vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
head_size=self.hidden_size,
num_heads=self.num_heads,
num_encoder_layers=self.num_encoder_layers,
max_length=self.max_length,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
# Compute the prediction scores
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder)
return logits, pooled
def __call__(self, hidden_states):
hidden_states = FlaxBertLMPredictionHead(
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
)(hidden_states)
return hidden_states
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxPreTrainedModel):
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
model_class = FlaxBertModule
config_class = BertConfig
base_model_prefix = "bert"
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    """
    Initialise the wrapped module's parameters using a dummy input.

    Args:
        rng: PRNG key; it is split into one key for parameter initialisation and one for dropout.
        input_shape: Shape of the all-zero ``input_ids`` array used to trace the module.

    Returns:
        The ``"params"`` entry of the variables produced by ``self.module.init``.
    """
    # Only dummy token ids are built here; `_check_inputs` supplies the other three inputs.
    dummy_input_ids = jnp.zeros(input_shape, dtype="i4")
    dummy_inputs = self._check_inputs(dummy_input_ids, None, None, None)

    params_key, dropout_key = jax.random.split(rng)
    variables = self.module.init({"params": params_key, "dropout": dropout_key}, *dummy_inputs)
    return variables["params"]
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
jax_state = dict(pt_state)
@@ -501,6 +463,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
key = key.replace("weight", "kernel")
jax_state[key] = tensor
if "decoder.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor.T
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
@@ -526,7 +493,7 @@ class FlaxBertModel(FlaxPreTrainedModel):
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
@@ -539,6 +506,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Hack to correctly load some pytorch models
if "predictions.bias" in key:
del jax_state[key]
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
@@ -555,7 +527,22 @@ class FlaxBertModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
@@ -566,10 +553,12 @@ class FlaxBertModel(FlaxPreTrainedModel):
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
hidden_act=config.hidden_act,
dtype=dtype,
**kwargs,
)
super().__init__(config, module, state, seed)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
@@ -601,34 +590,62 @@ class FlaxBertModel(FlaxPreTrainedModel):
rngs=rngs,
)
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
class FlaxBertModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
return input_ids, attention_mask, token_type_ids, position_ids
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if not self.add_pooling_layer:
return encoder
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
class FlaxBertForMaskedLM(FlaxBertModel):
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
super().__init__(config, state, seed, dtype)
self._module = FlaxBertOnlyMLMHead(
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForMaskedLMModule(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
@@ -636,10 +653,13 @@ class FlaxBertForMaskedLM(FlaxBertModel):
head_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_encoder_layers=config.num_hidden_layers,
max_length=config.max_length,
max_length=config.max_position_embeddings,
hidden_act=config.hidden_act,
**kwargs,
)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def __call__(
self,
input_ids,
@@ -659,7 +679,7 @@ class FlaxBertForMaskedLM(FlaxBertModel):
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
pooled, logits = self.module.apply(
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
@@ -669,4 +689,45 @@ class FlaxBertForMaskedLM(FlaxBertModel):
rngs=rngs,
)
return logits, pooled
class FlaxBertForMaskedLMModule(nn.Module):
    """
    BERT encoder without a pooling layer, followed by dropout and a masked-language-modeling head.

    The encoder is registered under the name "bert" and the head under "cls".
    NOTE(review): presumably chosen to mirror the PyTorch `BertForMaskedLM` parameter layout - confirm.
    """

    vocab_size: int
    hidden_size: int
    intermediate_size: int
    # NOTE(review): declared but never forwarded - the encoder below receives `hidden_size` as its head size.
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    hidden_act: str
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        """Run the encoder, apply dropout, and return a 1-tuple holding the MLM head's output."""
        # With `add_pooling_layer=False` the encoder returns only its hidden states (no pooled output).
        bert = FlaxBertModule(
            name="bert",
            add_pooling_layer=False,
            vocab_size=self.vocab_size,
            type_vocab_size=self.type_vocab_size,
            max_length=self.max_length,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            head_size=self.hidden_size,
            num_heads=self.num_heads,
            num_encoder_layers=self.num_encoder_layers,
            hidden_act=self.hidden_act,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )
        sequence_output = bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)

        # Dropout on the encoder output before it reaches the prediction head.
        dropout = nn.Dropout(rate=self.dropout_rate)
        sequence_output = dropout(sequence_output, deterministic=deterministic)

        # Compute the prediction scores
        mlm_head = FlaxBertOnlyMLMHead(
            name="cls", vocab_size=self.vocab_size, hidden_act=self.hidden_act, dtype=self.dtype
        )
        logits = mlm_head(sequence_output)

        return (logits,)