Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more 🤗 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Initial working (forward + backward) for Flax MLM training example. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Simply code Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Addressing comments, using module and moving to LM task. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore parameter name "module" wrongly renamed model. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore correct output ordering... Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Actually commit the example 😅 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Add FlaxBertModelForMaskedLM after rebasing. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make it possible to initialize the training from scratch Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Reuse flax linen example of cross entropy loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added specific data collator for flax Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Remove todo for data collator Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added evaluation step Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added ability to provide dtype to support bfloat16 on TPU Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable flax tensorboard output Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable jax.pmap support. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Ensure batches are correctly sized to be dispatched with jax.pmap Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable bfloat16 with --fp16 cmdline args Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Correctly export metrics to tensorboard Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added dropout and ability to use it. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Effectively enable & disable during training and evaluation steps. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Oops. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable specifying kernel initializer scale Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Style. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added warmup step to the learning rate scheduler. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix typo. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Print training loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make style Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix linter issue (flake8) Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix model matching Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix dummies Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix non default dtype on Flax models Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Use the same create_position_ids_from_input_ids for FlaxRoberta Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make Roberta attention as Bert Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix copy Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Wording. Co-authored-by: Marc van Zee <marcvanzee@gmail.com> Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
This commit is contained in:
@@ -59,4 +59,4 @@ if is_tf_available():
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_bert import FlaxBertModel
|
||||
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
|
||||
@@ -13,13 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||
@@ -101,8 +102,8 @@ class FlaxBertLayerNorm(nn.Module):
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||
scale_init: jnp.ndarray = nn.initializers.ones
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
@@ -122,11 +123,13 @@ class FlaxBertLayerNorm(nn.Module):
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
|
||||
return y
|
||||
|
||||
|
||||
@@ -138,7 +141,9 @@ class FlaxBertEmbedding(nn.Module):
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||
kernel_init_scale: float = 0.2
|
||||
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
@@ -153,63 +158,105 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
kernel_init_scale: float = 0.2
|
||||
dropout_rate: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
|
||||
# Embed
|
||||
w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||
jnp.atleast_2d(input_ids.astype("i4"))
|
||||
)
|
||||
p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||
jnp.atleast_2d(position_ids.astype("i4"))
|
||||
)
|
||||
t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||
)
|
||||
w_emb = FlaxBertEmbedding(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="word_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(input_ids.astype("i4")))
|
||||
p_emb = FlaxBertEmbedding(
|
||||
self.max_length,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="position_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(position_ids.astype("i4")))
|
||||
t_emb = FlaxBertEmbedding(
|
||||
self.type_vocab_size,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="token_type_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||
|
||||
# Sum all embeddings
|
||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||
|
||||
# Layer Norm
|
||||
layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
|
||||
|
||||
return layer_norm
|
||||
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
|
||||
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
|
||||
return embeddings
|
||||
|
||||
|
||||
class FlaxBertAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
self_att = nn.attention.SelfAttention(
|
||||
num_heads=self.num_heads,
|
||||
qkv_features=self.head_size,
|
||||
dropout_rate=self.dropout_rate,
|
||||
deterministic=deterministic,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask)
|
||||
|
||||
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxBertIntermediate(nn.Module):
|
||||
output_size: int
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
dense = nn.Dense(
|
||||
features=self.output_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state)
|
||||
return gelu(dense)
|
||||
|
||||
|
||||
class FlaxBertOutput(nn.Module):
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||
hidden_state = nn.Dense(
|
||||
attention_output.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(intermediate_output)
|
||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
||||
hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
@@ -217,12 +264,26 @@ class FlaxBertLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
|
||||
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||
output = FlaxBertOutput(name="output")(intermediate, attention)
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
attention = FlaxBertAttention(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="attention",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
intermediate = FlaxBertIntermediate(
|
||||
self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype
|
||||
)(attention)
|
||||
output = FlaxBertOutput(
|
||||
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
||||
)(intermediate, attention, deterministic=deterministic)
|
||||
|
||||
return output
|
||||
|
||||
@@ -236,9 +297,12 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
def __call__(self, inputs, attention_mask, deterministic: bool = True):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
@@ -246,8 +310,16 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||
input_i = layer(input_i, attention_mask)
|
||||
layer = FlaxBertLayer(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name=f"{i}",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
input_i = layer(input_i, attention_mask, deterministic=deterministic)
|
||||
return input_i
|
||||
|
||||
|
||||
@@ -256,21 +328,39 @@ class FlaxBertEncoder(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
layer = FlaxBertLayerCollection(
|
||||
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||
)(hidden_state, attention_mask)
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="layer",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
return layer
|
||||
|
||||
|
||||
class FlaxBertPooler(nn.Module):
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
out = nn.Dense(
|
||||
hidden_state.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(cls_token)
|
||||
return nn.tanh(out)
|
||||
|
||||
|
||||
class FlaxBertModule(nn.Module):
|
||||
@@ -282,24 +372,104 @@ class FlaxBertModule(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxBertEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
self.type_vocab_size,
|
||||
self.max_length,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="embeddings",
|
||||
dtype=self.dtype,
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxBertEncoder(
|
||||
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||
)(embeddings, attention_mask)
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="encoder",
|
||||
dtype=self.dtype,
|
||||
)(embeddings, attention_mask, deterministic=deterministic)
|
||||
|
||||
pooled = FlaxBertPooler(name="pooler")(encoder)
|
||||
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
class FlaxBertPredictionHeadTransform(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
|
||||
hidden_states = nn.elu(hidden_states) # TODO: ACT2FN[config.hidden_act]
|
||||
return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
|
||||
|
||||
|
||||
class FlaxBertLMPredictionHead(nn.Module):
|
||||
vocab_size: int
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_states):
|
||||
# TODO: The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
# Need a link between the two variables so that the bias is correctly
|
||||
# resized with `resize_token_embeddings`
|
||||
|
||||
hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states)
|
||||
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxBertOnlyMLMHead(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
head_size: int
|
||||
num_heads: int
|
||||
num_encoder_layers: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
dropout_rate: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
encoder, pooled = FlaxBertModule(
|
||||
vocab_size=self.vocab_size,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
head_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
num_encoder_layers=self.num_encoder_layers,
|
||||
max_length=self.max_length,
|
||||
dropout_rate=self.dropout_rate,
|
||||
dtype=self.dtype,
|
||||
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
# Compute the prediction scores
|
||||
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
||||
logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder)
|
||||
|
||||
return logits, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
@@ -385,8 +555,8 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
|
||||
model = FlaxBertModule(
|
||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
||||
module = FlaxBertModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
@@ -395,16 +565,43 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
dropout_rate=config.hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
super().__init__(config, model, state, seed)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
super().__init__(config, module, state, seed)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
@@ -414,10 +611,62 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return self.model.apply(
|
||||
{"params": self.params},
|
||||
return input_ids, attention_mask, token_type_ids, position_ids
|
||||
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
|
||||
|
||||
class FlaxBertForMaskedLM(FlaxBertModel):
|
||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
||||
super().__init__(config, state, seed, dtype)
|
||||
|
||||
self._module = FlaxBertOnlyMLMHead(
|
||||
vocab_size=config.vocab_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
head_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
max_length=config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
pooled, logits = self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return logits, pooled
|
||||
|
||||
Reference in New Issue
Block a user