[Flax] Adapt Flax models to new structure (#9484)
* Create modeling_flax_eletra with code copied from modeling_flax_bert * Add ElectraForMaskedLM and ElectraForPretraining * Add modeling test for Flax electra and fix naming and arg in Flax Electra model * Add documentation * Fix code style * Create modeling_flax_eletra with code copied from modeling_flax_bert * Add ElectraForMaskedLM and ElectraForPretraining * Add modeling test for Flax electra and fix naming and arg in Flax Electra model * Add documentation * Fix code style * Fix code quality * Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016 * Remove redundant ElectraPooler * save intermediate * adapt * correct bert flax design * adapt roberta as well * finish roberta flax * finish * apply suggestions * apply suggestions Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5c0bf39782
commit
0b98ca368f
@@ -97,6 +97,7 @@ class FlaxBertLayerNorm(nn.Module):
|
|||||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
hidden_size: int
|
||||||
epsilon: float = 1e-6
|
epsilon: float = 1e-6
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
bias: bool = True # If True, bias (beta) is added.
|
bias: bool = True # If True, bias (beta) is added.
|
||||||
@@ -106,7 +107,10 @@ class FlaxBertLayerNorm(nn.Module):
|
|||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
||||||
|
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
"""
|
"""
|
||||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||||
@@ -119,18 +123,17 @@ class FlaxBertLayerNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized inputs (the same shape as inputs).
|
Normalized inputs (the same shape as inputs).
|
||||||
"""
|
"""
|
||||||
features = x.shape[-1]
|
|
||||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||||
var = mean2 - jax.lax.square(mean)
|
var = mean2 - jax.lax.square(mean)
|
||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
|
mul = mul * jnp.asarray(self.gamma)
|
||||||
y = (x - mean) * mul
|
y = (x - mean) * mul
|
||||||
|
|
||||||
if self.bias:
|
if self.bias:
|
||||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
|
y = y + jnp.asarray(self.beta)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@@ -142,278 +145,232 @@ class FlaxBertEmbedding(nn.Module):
|
|||||||
|
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
hidden_size: int
|
hidden_size: int
|
||||||
kernel_init_scale: float = 0.2
|
initializer_range: float
|
||||||
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, inputs):
|
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
||||||
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
||||||
return jnp.take(embedding, inputs, axis=0)
|
|
||||||
|
def __call__(self, input_ids):
|
||||||
|
return jnp.take(self.embeddings, input_ids, axis=0)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertEmbeddings(nn.Module):
|
class FlaxBertEmbeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
vocab_size: int
|
config: BertConfig
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
self.word_embeddings = FlaxBertEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
# Embed
|
self.config.hidden_size,
|
||||||
w_emb = FlaxBertEmbedding(
|
initializer_range=self.config.initializer_range,
|
||||||
self.vocab_size,
|
|
||||||
self.hidden_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
name="word_embeddings",
|
name="word_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(input_ids.astype("i4")))
|
)
|
||||||
p_emb = FlaxBertEmbedding(
|
self.position_embeddings = FlaxBertEmbedding(
|
||||||
self.max_length,
|
self.config.max_position_embeddings,
|
||||||
self.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
initializer_range=self.config.initializer_range,
|
||||||
name="position_embeddings",
|
name="position_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(position_ids.astype("i4")))
|
)
|
||||||
t_emb = FlaxBertEmbedding(
|
self.token_type_embeddings = FlaxBertEmbedding(
|
||||||
self.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
initializer_range=self.config.initializer_range,
|
||||||
name="token_type_embeddings",
|
name="token_type_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(token_type_ids.astype("i4")))
|
)
|
||||||
|
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||||
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
|
# Embed
|
||||||
|
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
||||||
|
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
||||||
|
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||||
|
|
||||||
# Sum all embeddings
|
# Sum all embeddings
|
||||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
return embeddings
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertAttention(nn.Module):
|
class FlaxBertAttention(nn.Module):
|
||||||
num_heads: int
|
config: BertConfig
|
||||||
head_size: int
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
self.self_attention = nn.attention.SelfAttention(
|
||||||
|
num_heads=self.config.num_attention_heads,
|
||||||
|
qkv_features=self.config.hidden_size,
|
||||||
|
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
bias_init=jax.nn.initializers.zeros,
|
||||||
|
name="self",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||||
self_att = nn.attention.SelfAttention(
|
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
num_heads=self.num_heads,
|
|
||||||
qkv_features=self.head_size,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
deterministic=deterministic,
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
bias_init=jax.nn.initializers.zeros,
|
|
||||||
name="self",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask)
|
|
||||||
|
|
||||||
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
||||||
return layer_norm
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertIntermediate(nn.Module):
|
class FlaxBertIntermediate(nn.Module):
|
||||||
output_size: int
|
config: BertConfig
|
||||||
hidden_act: str = "gelu"
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states):
|
self.dense = nn.Dense(
|
||||||
hidden_states = nn.Dense(
|
self.config.intermediate_size,
|
||||||
features=self.output_size,
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_states)
|
)
|
||||||
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
|
|
||||||
|
def __call__(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertOutput(nn.Module):
|
class FlaxBertOutput(nn.Module):
|
||||||
dropout_rate: float = 0.0
|
config: BertConfig
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
self.dense = nn.Dense(
|
||||||
hidden_states = nn.Dense(
|
self.config.hidden_size,
|
||||||
attention_output.shape[-1],
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(intermediate_output)
|
)
|
||||||
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
|
hidden_states = self.layer_norm(hidden_states + attention_output)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLayer(nn.Module):
|
class FlaxBertLayer(nn.Module):
|
||||||
num_heads: int
|
config: BertConfig
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
|
||||||
attention = FlaxBertAttention(
|
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
||||||
self.num_heads,
|
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
|
||||||
self.head_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="attention",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
|
||||||
intermediate = FlaxBertIntermediate(
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name="intermediate",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(attention)
|
|
||||||
output = FlaxBertOutput(
|
|
||||||
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
|
||||||
)(intermediate, attention, deterministic=deterministic)
|
|
||||||
|
|
||||||
return output
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
|
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
hidden_states = self.intermediate(attention_output)
|
||||||
|
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLayerCollection(nn.Module):
|
class FlaxBertLayerCollection(nn.Module):
|
||||||
"""
|
config: BertConfig
|
||||||
Stores N BertLayer(s)
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, inputs, attention_mask, deterministic: bool = True):
|
self.layers = [
|
||||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
|
||||||
# Initialize input / output
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
input_i = inputs
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
# Forward over all encoders
|
return hidden_states
|
||||||
for i in range(self.num_layers):
|
|
||||||
layer = FlaxBertLayer(
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name=f"{i}",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
input_i = layer(input_i, attention_mask, deterministic=deterministic)
|
|
||||||
return input_i
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertEncoder(nn.Module):
|
class FlaxBertEncoder(nn.Module):
|
||||||
num_layers: int
|
config: BertConfig
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
layer = FlaxBertLayerCollection(
|
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self.num_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="layer",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPooler(nn.Module):
|
class FlaxBertPooler(nn.Module):
|
||||||
kernel_init_scale: float = 0.2
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states):
|
self.dense = nn.Dense(
|
||||||
cls_token = hidden_states[:, 0]
|
self.config.hidden_size,
|
||||||
out = nn.Dense(
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
hidden_states.shape[-1],
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(cls_token)
|
)
|
||||||
return nn.tanh(out)
|
|
||||||
|
def __call__(self, hidden_states):
|
||||||
|
cls_hidden_state = hidden_states[:, 0]
|
||||||
|
cls_hidden_state = self.dense(cls_hidden_state)
|
||||||
|
return nn.tanh(cls_hidden_state)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPredictionHeadTransform(nn.Module):
|
class FlaxBertPredictionHeadTransform(nn.Module):
|
||||||
hidden_act: str = "gelu"
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
|
||||||
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
|
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
hidden_states = self.activation(hidden_states)
|
||||||
return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
|
return self.layer_norm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLMPredictionHead(nn.Module):
|
class FlaxBertLMPredictionHead(nn.Module):
|
||||||
vocab_size: int
|
config: BertConfig
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
|
||||||
|
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
# TODO: The output weights are the same as the input embeddings, but there is
|
# TODO: The output weights are the same as the input embeddings, but there is
|
||||||
# an output-only bias for each token.
|
# an output-only bias for each token.
|
||||||
# Need a link between the two variables so that the bias is correctly
|
# Need a link between the two variables so that the bias is correctly
|
||||||
# resized with `resize_token_embeddings`
|
# resized with `resize_token_embeddings`
|
||||||
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = FlaxBertPredictionHeadTransform(
|
hidden_states = self.decoder(hidden_states)
|
||||||
name="transform", hidden_act=self.hidden_act, dtype=self.dtype
|
|
||||||
)(hidden_states)
|
|
||||||
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertOnlyMLMHead(nn.Module):
|
class FlaxBertOnlyMLMHead(nn.Module):
|
||||||
vocab_size: int
|
config: BertConfig
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = FlaxBertLMPredictionHead(
|
hidden_states = self.mlm_head(hidden_states)
|
||||||
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
|
|
||||||
)(hidden_states)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -543,20 +500,7 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||||
):
|
):
|
||||||
module = FlaxBertModule(
|
module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
type_vocab_size=config.type_vocab_size,
|
|
||||||
max_length=config.max_position_embeddings,
|
|
||||||
num_encoder_layers=config.num_hidden_layers,
|
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
head_size=config.hidden_size,
|
|
||||||
intermediate_size=config.intermediate_size,
|
|
||||||
dropout_rate=config.hidden_dropout_prob,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
dtype=dtype,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@@ -592,71 +536,34 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlaxBertModule(nn.Module):
|
class FlaxBertModule(nn.Module):
|
||||||
vocab_size: int
|
config: BertConfig
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
add_pooling_layer: bool = True
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
||||||
|
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
|
||||||
|
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
# Embedding
|
hidden_states = self.embeddings(
|
||||||
embeddings = FlaxBertEmbeddings(
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
self.vocab_size,
|
)
|
||||||
self.hidden_size,
|
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self.type_vocab_size,
|
|
||||||
self.max_length,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="embeddings",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
# N stacked encoding layers
|
|
||||||
encoder = FlaxBertEncoder(
|
|
||||||
self.num_encoder_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name="encoder",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(embeddings, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
if not self.add_pooling_layer:
|
if not self.add_pooling_layer:
|
||||||
return encoder
|
return hidden_states
|
||||||
|
|
||||||
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
pooled = self.pooler(hidden_states)
|
||||||
return encoder, pooled
|
return hidden_states, pooled
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||||
):
|
):
|
||||||
module = FlaxBertForMaskedLMModule(
|
module = FlaxBertForMaskedLMModule(config, **kwargs)
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
type_vocab_size=config.type_vocab_size,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
intermediate_size=config.intermediate_size,
|
|
||||||
head_size=config.hidden_size,
|
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
num_encoder_layers=config.num_hidden_layers,
|
|
||||||
max_length=config.max_position_embeddings,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@@ -691,43 +598,32 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLMModule(nn.Module):
|
class FlaxBertForMaskedLMModule(nn.Module):
|
||||||
vocab_size: int
|
config: BertConfig
|
||||||
hidden_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
head_size: int
|
|
||||||
num_heads: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
hidden_act: str
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.encoder = FlaxBertModule(
|
||||||
|
config=self.config,
|
||||||
|
add_pooling_layer=False,
|
||||||
|
name="bert",
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
self.mlm_head = FlaxBertOnlyMLMHead(
|
||||||
|
config=self.config,
|
||||||
|
name="cls",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
encoder = FlaxBertModule(
|
hidden_states = self.encoder(
|
||||||
vocab_size=self.vocab_size,
|
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||||
type_vocab_size=self.type_vocab_size,
|
)
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
head_size=self.hidden_size,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_encoder_layers=self.num_encoder_layers,
|
|
||||||
max_length=self.max_length,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
dtype=self.dtype,
|
|
||||||
add_pooling_layer=False,
|
|
||||||
name="bert",
|
|
||||||
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
|
||||||
|
|
||||||
# Compute the prediction scores
|
# Compute the prediction scores
|
||||||
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
logits = FlaxBertOnlyMLMHead(
|
logits = self.mlm_head(hidden_states)
|
||||||
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
|
|
||||||
)(encoder)
|
|
||||||
|
|
||||||
return (logits,)
|
return (logits,)
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ class FlaxRobertaLayerNorm(nn.Module):
|
|||||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
hidden_size: int
|
||||||
epsilon: float = 1e-6
|
epsilon: float = 1e-6
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
bias: bool = True # If True, bias (beta) is added.
|
bias: bool = True # If True, bias (beta) is added.
|
||||||
@@ -123,7 +124,10 @@ class FlaxRobertaLayerNorm(nn.Module):
|
|||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
||||||
|
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
"""
|
"""
|
||||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||||
@@ -136,18 +140,17 @@ class FlaxRobertaLayerNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized inputs (the same shape as inputs).
|
Normalized inputs (the same shape as inputs).
|
||||||
"""
|
"""
|
||||||
features = x.shape[-1]
|
|
||||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||||
var = mean2 - jax.lax.square(mean)
|
var = mean2 - jax.lax.square(mean)
|
||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
|
mul = mul * jnp.asarray(self.gamma)
|
||||||
y = (x - mean) * mul
|
y = (x - mean) * mul
|
||||||
|
|
||||||
if self.bias:
|
if self.bias:
|
||||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
|
y = y + jnp.asarray(self.beta)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@@ -160,243 +163,202 @@ class FlaxRobertaEmbedding(nn.Module):
|
|||||||
|
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
hidden_size: int
|
hidden_size: int
|
||||||
kernel_init_scale: float = 0.2
|
initializer_range: float
|
||||||
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, inputs):
|
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
||||||
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
||||||
return jnp.take(embedding, inputs, axis=0)
|
|
||||||
|
def __call__(self, input_ids):
|
||||||
|
return jnp.take(self.embeddings, input_ids, axis=0)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
||||||
class FlaxRobertaEmbeddings(nn.Module):
|
class FlaxRobertaEmbeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
vocab_size: int
|
config: RobertaConfig
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
self.word_embeddings = FlaxRobertaEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
# Embed
|
self.config.hidden_size,
|
||||||
w_emb = FlaxRobertaEmbedding(
|
initializer_range=self.config.initializer_range,
|
||||||
self.vocab_size,
|
|
||||||
self.hidden_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
name="word_embeddings",
|
name="word_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(input_ids.astype("i4")))
|
)
|
||||||
p_emb = FlaxRobertaEmbedding(
|
self.position_embeddings = FlaxRobertaEmbedding(
|
||||||
self.max_length,
|
self.config.max_position_embeddings,
|
||||||
self.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
initializer_range=self.config.initializer_range,
|
||||||
name="position_embeddings",
|
name="position_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(position_ids.astype("i4")))
|
)
|
||||||
t_emb = FlaxRobertaEmbedding(
|
self.token_type_embeddings = FlaxRobertaEmbedding(
|
||||||
self.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
initializer_range=self.config.initializer_range,
|
||||||
name="token_type_embeddings",
|
name="token_type_embeddings",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(jnp.atleast_2d(token_type_ids.astype("i4")))
|
)
|
||||||
|
self.layer_norm = FlaxRobertaLayerNorm(
|
||||||
|
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
|
# Embed
|
||||||
|
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
||||||
|
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
||||||
|
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||||
|
|
||||||
# Sum all embeddings
|
# Sum all embeddings
|
||||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
return embeddings
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||||
class FlaxRobertaAttention(nn.Module):
|
class FlaxRobertaAttention(nn.Module):
|
||||||
num_heads: int
|
config: RobertaConfig
|
||||||
head_size: int
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
self.self_attention = nn.attention.SelfAttention(
|
||||||
|
num_heads=self.config.num_attention_heads,
|
||||||
|
qkv_features=self.config.hidden_size,
|
||||||
|
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
bias_init=jax.nn.initializers.zeros,
|
||||||
|
name="self",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
self.layer_norm = FlaxRobertaLayerNorm(
|
||||||
|
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||||
self_att = nn.attention.SelfAttention(
|
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
num_heads=self.num_heads,
|
|
||||||
qkv_features=self.head_size,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
deterministic=deterministic,
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
bias_init=jax.nn.initializers.zeros,
|
|
||||||
name="self",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask)
|
|
||||||
|
|
||||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
||||||
return layer_norm
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||||
class FlaxRobertaIntermediate(nn.Module):
|
class FlaxRobertaIntermediate(nn.Module):
|
||||||
output_size: int
|
config: RobertaConfig
|
||||||
hidden_act: str = "gelu"
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states):
|
self.dense = nn.Dense(
|
||||||
hidden_states = nn.Dense(
|
self.config.intermediate_size,
|
||||||
features=self.output_size,
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_states)
|
)
|
||||||
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
|
|
||||||
|
def __call__(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||||
class FlaxRobertaOutput(nn.Module):
|
class FlaxRobertaOutput(nn.Module):
|
||||||
dropout_rate: float = 0.0
|
config: RobertaConfig
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
self.dense = nn.Dense(
|
||||||
hidden_states = nn.Dense(
|
self.config.hidden_size,
|
||||||
attention_output.shape[-1],
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(intermediate_output)
|
)
|
||||||
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
self.layer_norm = FlaxRobertaLayerNorm(
|
||||||
|
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
|
hidden_states = self.layer_norm(hidden_states + attention_output)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
|
||||||
class FlaxRobertaLayer(nn.Module):
|
class FlaxRobertaLayer(nn.Module):
|
||||||
num_heads: int
|
config: RobertaConfig
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
|
||||||
attention = FlaxRobertaAttention(
|
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
||||||
self.num_heads,
|
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
|
||||||
self.head_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="attention",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
|
||||||
intermediate = FlaxRobertaIntermediate(
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name="intermediate",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(attention)
|
|
||||||
output = FlaxRobertaOutput(
|
|
||||||
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
|
||||||
)(intermediate, attention, deterministic=deterministic)
|
|
||||||
|
|
||||||
return output
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
|
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
hidden_states = self.intermediate(attention_output)
|
||||||
|
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
||||||
class FlaxRobertaLayerCollection(nn.Module):
|
class FlaxRobertaLayerCollection(nn.Module):
|
||||||
"""
|
config: RobertaConfig
|
||||||
Stores N RobertaLayer(s)
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, inputs, attention_mask, deterministic: bool = True):
|
self.layers = [
|
||||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
|
||||||
# Initialize input / output
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
input_i = inputs
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
# Forward over all encoders
|
return hidden_states
|
||||||
for i in range(self.num_layers):
|
|
||||||
layer = FlaxRobertaLayer(
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name=f"{i}",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
input_i = layer(input_i, attention_mask, deterministic=deterministic)
|
|
||||||
return input_i
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
||||||
class FlaxRobertaEncoder(nn.Module):
|
class FlaxRobertaEncoder(nn.Module):
|
||||||
num_layers: int
|
config: RobertaConfig
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
layer = FlaxRobertaLayerCollection(
|
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self.num_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="layer",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||||
class FlaxRobertaPooler(nn.Module):
|
class FlaxRobertaPooler(nn.Module):
|
||||||
kernel_init_scale: float = 0.2
|
config: RobertaConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
def __call__(self, hidden_states):
|
self.dense = nn.Dense(
|
||||||
cls_token = hidden_states[:, 0]
|
self.config.hidden_size,
|
||||||
out = nn.Dense(
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
hidden_states.shape[-1],
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(cls_token)
|
)
|
||||||
return nn.tanh(out)
|
|
||||||
|
def __call__(self, hidden_states):
|
||||||
|
cls_hidden_state = hidden_states[:, 0]
|
||||||
|
cls_hidden_state = self.dense(cls_hidden_state)
|
||||||
|
return nn.tanh(cls_hidden_state)
|
||||||
|
|
||||||
|
|
||||||
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||||
@@ -520,21 +482,7 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
|||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = FlaxRobertaModule(
|
module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
type_vocab_size=config.type_vocab_size,
|
|
||||||
max_length=config.max_position_embeddings,
|
|
||||||
num_encoder_layers=config.num_hidden_layers,
|
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
head_size=config.hidden_size,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
intermediate_size=config.intermediate_size,
|
|
||||||
dropout_rate=config.hidden_dropout_prob,
|
|
||||||
dtype=dtype,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@@ -570,50 +518,24 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
|||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||||
class FlaxRobertaModule(nn.Module):
|
class FlaxRobertaModule(nn.Module):
|
||||||
vocab_size: int
|
config: RobertaConfig
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
hidden_act: str = "gelu"
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
add_pooling_layer: bool = True
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
@nn.compact
|
def setup(self):
|
||||||
|
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
||||||
|
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
|
||||||
|
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
# Embedding
|
hidden_states = self.embeddings(
|
||||||
embeddings = FlaxRobertaEmbeddings(
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
self.vocab_size,
|
)
|
||||||
self.hidden_size,
|
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self.type_vocab_size,
|
|
||||||
self.max_length,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="embeddings",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
# N stacked encoding layers
|
|
||||||
encoder = FlaxRobertaEncoder(
|
|
||||||
self.num_encoder_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
name="encoder",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(embeddings, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
if not self.add_pooling_layer:
|
if not self.add_pooling_layer:
|
||||||
return encoder
|
return hidden_states
|
||||||
|
|
||||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
pooled = self.pooler(hidden_states)
|
||||||
return encoder, pooled
|
return hidden_states, pooled
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
|
|||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
|
|||||||
fx_outputs = fx_model(**inputs_dict)
|
fx_outputs = fx_model(**inputs_dict)
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
|
|||||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||||
|
|
||||||
@require_flax
|
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
|
|||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||||
|
|
||||||
@require_flax
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
|
|||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
@require_flax
|
|
||||||
def test_naming_convention(self):
|
def test_naming_convention(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model_class_name = model_class.__name__
|
model_class_name = model_class.__name__
|
||||||
|
|||||||
Reference in New Issue
Block a user