From 0b98ca368f88aeb03334c7e5dfac375cb1b069d2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 18 Mar 2021 09:44:17 +0300 Subject: [PATCH] [Flax] Adapt Flax models to new structure (#9484) * Create modeling_flax_eletra with code copied from modeling_flax_bert * Add ElectraForMaskedLM and ElectraForPretraining * Add modeling test for Flax electra and fix naming and arg in Flax Electra model * Add documentation * Fix code style * Create modeling_flax_eletra with code copied from modeling_flax_bert * Add ElectraForMaskedLM and ElectraForPretraining * Add modeling test for Flax electra and fix naming and arg in Flax Electra model * Add documentation * Fix code style * Fix code quality * Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016 * Remove redundant ElectraPooler * save intermediate * adapt * correct bert flax design * adapt roberta as well * finish roberta flax * finish * apply suggestions * apply suggestions Co-authored-by: Chris Nguyen --- .../models/bert/modeling_flax_bert.py | 452 +++++++----------- .../models/roberta/modeling_flax_roberta.py | 358 ++++++-------- tests/test_modeling_flax_common.py | 6 +- 3 files changed, 316 insertions(+), 500 deletions(-) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 9def58dad9..97a219f12c 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -97,6 +97,7 @@ class FlaxBertLayerNorm(nn.Module): Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data. """ + hidden_size: int epsilon: float = 1e-6 dtype: jnp.dtype = jnp.float32 # the dtype of the computation bias: bool = True # If True, bias (beta) is added. @@ -106,7 +107,10 @@ class FlaxBertLayerNorm(nn.Module): scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - @nn.compact + def setup(self): + self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,)) + self.beta = self.param("beta", self.scale_init, (self.hidden_size,)) + def __call__(self, x): """ Applies layer normalization on the input. It normalizes the activations of the layer for each given example in @@ -119,18 +123,17 @@ class FlaxBertLayerNorm(nn.Module): Returns: Normalized inputs (the same shape as inputs). """ - features = x.shape[-1] mean = jnp.mean(x, axis=-1, keepdims=True) mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) var = mean2 - jax.lax.square(mean) mul = jax.lax.rsqrt(var + self.epsilon) if self.scale: - mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,))) + mul = mul * jnp.asarray(self.gamma) y = (x - mean) * mul if self.bias: - y = y + jnp.asarray(self.param("beta", self.bias_init, (features,))) + y = y + jnp.asarray(self.beta) return y @@ -142,278 +145,232 @@ class FlaxBertEmbedding(nn.Module): vocab_size: int hidden_size: int - kernel_init_scale: float = 0.2 - emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale) + initializer_range: float dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, inputs): - embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) - return jnp.take(embedding, inputs, axis=0) + def setup(self): + init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range) + self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size)) + + def __call__(self, input_ids): + return jnp.take(self.embeddings, input_ids, axis=0) class FlaxBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - kernel_init_scale: float = 0.2 - dropout_rate: float = 0.0 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - - # Embed - w_emb = FlaxBertEmbedding( - self.vocab_size, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + def setup(self): + self.word_embeddings = FlaxBertEmbedding( + self.config.vocab_size, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="word_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(input_ids.astype("i4"))) - p_emb = FlaxBertEmbedding( - self.max_length, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + ) + self.position_embeddings = FlaxBertEmbedding( + self.config.max_position_embeddings, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="position_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(position_ids.astype("i4"))) - t_emb = FlaxBertEmbedding( - self.type_vocab_size, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + ) + self.token_type_embeddings = FlaxBertEmbedding( + self.config.type_vocab_size, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="token_type_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(token_type_ids.astype("i4"))) + ) + self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4"))) + position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4"))) + token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4"))) # Sum all embeddings - summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb + hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings # Layer Norm - layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb) - embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic) - return embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states class FlaxBertAttention(nn.Module): - num_heads: int - head_size: int - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + def setup(self): + self.self_attention = nn.attention.SelfAttention( + num_heads=self.config.num_attention_heads, + qkv_features=self.config.hidden_size, + dropout_rate=self.config.attention_probs_dropout_prob, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + bias_init=jax.nn.initializers.zeros, + name="self", + dtype=self.dtype, + ) + self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask, deterministic=True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_att = nn.attention.SelfAttention( - num_heads=self.num_heads, - qkv_features=self.head_size, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), - bias_init=jax.nn.initializers.zeros, - name="self", - dtype=self.dtype, - )(hidden_states, attention_mask) + self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic) - layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states) - return layer_norm + hidden_states = self.layer_norm(self_attn_output + hidden_states) + return hidden_states class FlaxBertIntermediate(nn.Module): - output_size: int - hidden_act: str = "gelu" - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states): - hidden_states = nn.Dense( - features=self.output_size, - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(hidden_states) - hidden_states = ACT2FN[self.hidden_act](hidden_states) + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) return hidden_states class FlaxBertOutput(nn.Module): - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, intermediate_output, attention_output, deterministic: bool = True): - hidden_states = nn.Dense( - attention_output.shape[-1], - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(intermediate_output) - hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic) - hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.layer_norm(hidden_states + attention_output) return hidden_states class FlaxBertLayer(nn.Module): - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - attention = FlaxBertAttention( - self.num_heads, - self.head_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="attention", - dtype=self.dtype, - )(hidden_states, attention_mask, deterministic=deterministic) - intermediate = FlaxBertIntermediate( - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - hidden_act=self.hidden_act, - name="intermediate", - dtype=self.dtype, - )(attention) - output = FlaxBertOutput( - kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype - )(intermediate, attention, deterministic=deterministic) + def setup(self): + self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype) + self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype) + self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype) - return output + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + return hidden_states class FlaxBertLayerCollection(nn.Module): - """ - Stores N BertLayer(s) - """ - - num_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, inputs, attention_mask, deterministic: bool = True): - assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" + def setup(self): + self.layers = [ + FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] - # Initialize input / output - input_i = inputs - - # Forward over all encoders - for i in range(self.num_layers): - layer = FlaxBertLayer( - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - hidden_act=self.hidden_act, - name=f"{i}", - dtype=self.dtype, - ) - input_i = layer(input_i, attention_mask, deterministic=deterministic) - return input_i + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) + return hidden_states class FlaxBertEncoder(nn.Module): - num_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact + def setup(self): + self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype) + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - layer = FlaxBertLayerCollection( - self.num_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - hidden_act=self.hidden_act, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="layer", - dtype=self.dtype, - )(hidden_states, attention_mask, deterministic=deterministic) - return layer + return self.layers(hidden_states, attention_mask, deterministic=deterministic) class FlaxBertPooler(nn.Module): - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states): - cls_token = hidden_states[:, 0] - out = nn.Dense( - hidden_states.shape[-1], - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(cls_token) - return nn.tanh(out) + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) class FlaxBertPredictionHeadTransform(nn.Module): - hidden_act: str = "gelu" + config: BertConfig dtype: jnp.dtype = jnp.float32 - @nn.compact + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + def __call__(self, hidden_states): - hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states) - hidden_states = ACT2FN[self.hidden_act](hidden_states) - return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return self.layer_norm(hidden_states) class FlaxBertLMPredictionHead(nn.Module): - vocab_size: int - hidden_act: str = "gelu" + config: BertConfig dtype: jnp.dtype = jnp.float32 - @nn.compact + def setup(self): + self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype) + def __call__(self, hidden_states): # TODO: The output weights are the same as the input embeddings, but there is # an output-only bias for each token. # Need a link between the two variables so that the bias is correctly # resized with `resize_token_embeddings` - - hidden_states = FlaxBertPredictionHeadTransform( - name="transform", hidden_act=self.hidden_act, dtype=self.dtype - )(hidden_states) - hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states) + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) return hidden_states class FlaxBertOnlyMLMHead(nn.Module): - vocab_size: int - hidden_act: str = "gelu" + config: BertConfig dtype: jnp.dtype = jnp.float32 - @nn.compact + def setup(self): + self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype) + def __call__(self, hidden_states): - hidden_states = FlaxBertLMPredictionHead( - vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype - )(hidden_states) + hidden_states = self.mlm_head(hidden_states) return hidden_states @@ -543,20 +500,7 @@ class FlaxBertModel(FlaxBertPreTrainedModel): def __init__( self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs ): - module = FlaxBertModule( - vocab_size=config.vocab_size, - hidden_size=config.hidden_size, - type_vocab_size=config.type_vocab_size, - max_length=config.max_position_embeddings, - num_encoder_layers=config.num_hidden_layers, - num_heads=config.num_attention_heads, - head_size=config.hidden_size, - intermediate_size=config.intermediate_size, - dropout_rate=config.hidden_dropout_prob, - hidden_act=config.hidden_act, - dtype=dtype, - **kwargs, - ) + module = FlaxBertModule(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @@ -592,71 +536,34 @@ class FlaxBertModel(FlaxBertPreTrainedModel): class FlaxBertModule(nn.Module): - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - num_encoder_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True - @nn.compact + def setup(self): + self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype) + self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype) + self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype) + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - # Embedding - embeddings = FlaxBertEmbeddings( - self.vocab_size, - self.hidden_size, - self.type_vocab_size, - self.max_length, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="embeddings", - dtype=self.dtype, - )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - - # N stacked encoding layers - encoder = FlaxBertEncoder( - self.num_encoder_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - hidden_act=self.hidden_act, - name="encoder", - dtype=self.dtype, - )(embeddings, attention_mask, deterministic=deterministic) + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic) if not self.add_pooling_layer: - return encoder + return hidden_states - pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) - return encoder, pooled + pooled = self.pooler(hidden_states) + return hidden_states, pooled class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): def __init__( self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs ): - module = FlaxBertForMaskedLMModule( - vocab_size=config.vocab_size, - type_vocab_size=config.type_vocab_size, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - head_size=config.hidden_size, - num_heads=config.num_attention_heads, - num_encoder_layers=config.num_hidden_layers, - max_length=config.max_position_embeddings, - hidden_act=config.hidden_act, - **kwargs, - ) + module = FlaxBertForMaskedLMModule(config, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @@ -691,43 +598,32 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): class FlaxBertForMaskedLMModule(nn.Module): - vocab_size: int - hidden_size: int - intermediate_size: int - head_size: int - num_heads: int - num_encoder_layers: int - type_vocab_size: int - max_length: int - hidden_act: str - dropout_rate: float = 0.0 + config: BertConfig dtype: jnp.dtype = jnp.float32 - @nn.compact + def setup(self): + self.encoder = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + name="bert", + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.mlm_head = FlaxBertOnlyMLMHead( + config=self.config, + name="cls", + dtype=self.dtype, + ) + def __call__( self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): # Model - encoder = FlaxBertModule( - vocab_size=self.vocab_size, - type_vocab_size=self.type_vocab_size, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - head_size=self.hidden_size, - num_heads=self.num_heads, - num_encoder_layers=self.num_encoder_layers, - max_length=self.max_length, - dropout_rate=self.dropout_rate, - hidden_act=self.hidden_act, - dtype=self.dtype, - add_pooling_layer=False, - name="bert", - )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + hidden_states = self.encoder( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) # Compute the prediction scores - encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic) - logits = FlaxBertOnlyMLMHead( - vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype - )(encoder) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.mlm_head(hidden_states) return (logits,) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 64fc2bdc4e..eeff923fcf 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -114,6 +114,7 @@ class FlaxRobertaLayerNorm(nn.Module): Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data. """ + hidden_size: int epsilon: float = 1e-6 dtype: jnp.dtype = jnp.float32 # the dtype of the computation bias: bool = True # If True, bias (beta) is added. @@ -123,7 +124,10 @@ class FlaxRobertaLayerNorm(nn.Module): scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - @nn.compact + def setup(self): + self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,)) + self.beta = self.param("beta", self.scale_init, (self.hidden_size,)) + def __call__(self, x): """ Applies layer normalization on the input. It normalizes the activations of the layer for each given example in @@ -136,18 +140,17 @@ class FlaxRobertaLayerNorm(nn.Module): Returns: Normalized inputs (the same shape as inputs). """ - features = x.shape[-1] mean = jnp.mean(x, axis=-1, keepdims=True) mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) var = mean2 - jax.lax.square(mean) mul = jax.lax.rsqrt(var + self.epsilon) if self.scale: - mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,))) + mul = mul * jnp.asarray(self.gamma) y = (x - mean) * mul if self.bias: - y = y + jnp.asarray(self.param("beta", self.bias_init, (features,))) + y = y + jnp.asarray(self.beta) return y @@ -160,243 +163,202 @@ class FlaxRobertaEmbedding(nn.Module): vocab_size: int hidden_size: int - kernel_init_scale: float = 0.2 - emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale) + initializer_range: float dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, inputs): - embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) - return jnp.take(embedding, inputs, axis=0) + def setup(self): + init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range) + self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size)) + + def __call__(self, input_ids): + return jnp.take(self.embeddings, input_ids, axis=0) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta class FlaxRobertaEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - kernel_init_scale: float = 0.2 - dropout_rate: float = 0.0 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): - - # Embed - w_emb = FlaxRobertaEmbedding( - self.vocab_size, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + def setup(self): + self.word_embeddings = FlaxRobertaEmbedding( + self.config.vocab_size, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="word_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(input_ids.astype("i4"))) - p_emb = FlaxRobertaEmbedding( - self.max_length, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + ) + self.position_embeddings = FlaxRobertaEmbedding( + self.config.max_position_embeddings, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="position_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(position_ids.astype("i4"))) - t_emb = FlaxRobertaEmbedding( - self.type_vocab_size, - self.hidden_size, - kernel_init_scale=self.kernel_init_scale, + ) + self.token_type_embeddings = FlaxRobertaEmbedding( + self.config.type_vocab_size, + self.config.hidden_size, + initializer_range=self.config.initializer_range, name="token_type_embeddings", dtype=self.dtype, - )(jnp.atleast_2d(token_type_ids.astype("i4"))) + ) + self.layer_norm = FlaxRobertaLayerNorm( + hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4"))) + position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4"))) + token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4"))) # Sum all embeddings - summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb + hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings # Layer Norm - layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb) - embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic) - return embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta class FlaxRobertaAttention(nn.Module): - num_heads: int - head_size: int - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + def setup(self): + self.self_attention = nn.attention.SelfAttention( + num_heads=self.config.num_attention_heads, + qkv_features=self.config.hidden_size, + dropout_rate=self.config.attention_probs_dropout_prob, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + bias_init=jax.nn.initializers.zeros, + name="self", + dtype=self.dtype, + ) + self.layer_norm = FlaxRobertaLayerNorm( + hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype + ) + + def __call__(self, hidden_states, attention_mask, deterministic=True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_att = nn.attention.SelfAttention( - num_heads=self.num_heads, - qkv_features=self.head_size, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), - bias_init=jax.nn.initializers.zeros, - name="self", - dtype=self.dtype, - )(hidden_states, attention_mask) + self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic) - layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states) - return layer_norm + hidden_states = self.layer_norm(self_attn_output + hidden_states) + return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta class FlaxRobertaIntermediate(nn.Module): - output_size: int - hidden_act: str = "gelu" - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states): - hidden_states = nn.Dense( - features=self.output_size, - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(hidden_states) - hidden_states = ACT2FN[self.hidden_act](hidden_states) + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta class FlaxRobertaOutput(nn.Module): - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, intermediate_output, attention_output, deterministic: bool = True): - hidden_states = nn.Dense( - attention_output.shape[-1], - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(intermediate_output) - hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic) - hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.layer_norm = FlaxRobertaLayerNorm( + hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype + ) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.layer_norm(hidden_states + attention_output) return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta class FlaxRobertaLayer(nn.Module): - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - attention = FlaxRobertaAttention( - self.num_heads, - self.head_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="attention", - dtype=self.dtype, - )(hidden_states, attention_mask, deterministic=deterministic) - intermediate = FlaxRobertaIntermediate( - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - hidden_act=self.hidden_act, - name="intermediate", - dtype=self.dtype, - )(attention) - output = FlaxRobertaOutput( - kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype - )(intermediate, attention, deterministic=deterministic) + def setup(self): + self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype) + self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype) + self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype) - return output + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta class FlaxRobertaLayerCollection(nn.Module): - """ - Stores N RobertaLayer(s) - """ - - num_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, inputs, attention_mask, deterministic: bool = True): - assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" + def setup(self): + self.layers = [ + FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] - # Initialize input / output - input_i = inputs - - # Forward over all encoders - for i in range(self.num_layers): - layer = FlaxRobertaLayer( - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - hidden_act=self.hidden_act, - name=f"{i}", - dtype=self.dtype, - ) - input_i = layer(input_i, attention_mask, deterministic=deterministic) - return input_i + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) + return hidden_states # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta class FlaxRobertaEncoder(nn.Module): - num_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact + def setup(self): + self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype) + def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - layer = FlaxRobertaLayerCollection( - self.num_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - hidden_act=self.hidden_act, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="layer", - dtype=self.dtype, - )(hidden_states, attention_mask, deterministic=deterministic) - return layer + return self.layers(hidden_states, attention_mask, deterministic=deterministic) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta class FlaxRobertaPooler(nn.Module): - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - @nn.compact - def __call__(self, hidden_states): - cls_token = hidden_states[:, 0] - out = nn.Dense( - hidden_states.shape[-1], - kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype), + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), name="dense", dtype=self.dtype, - )(cls_token) - return nn.tanh(out) + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): @@ -520,21 +482,7 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel): dtype: jnp.dtype = jnp.float32, **kwargs ): - module = FlaxRobertaModule( - vocab_size=config.vocab_size, - hidden_size=config.hidden_size, - type_vocab_size=config.type_vocab_size, - max_length=config.max_position_embeddings, - num_encoder_layers=config.num_hidden_layers, - num_heads=config.num_attention_heads, - head_size=config.hidden_size, - hidden_act=config.hidden_act, - intermediate_size=config.intermediate_size, - dropout_rate=config.hidden_dropout_prob, - dtype=dtype, - **kwargs, - ) - + module = FlaxRobertaModule(config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @@ -570,50 +518,24 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta class FlaxRobertaModule(nn.Module): - vocab_size: int - hidden_size: int - type_vocab_size: int - max_length: int - num_encoder_layers: int - num_heads: int - head_size: int - intermediate_size: int - hidden_act: str = "gelu" - dropout_rate: float = 0.0 - kernel_init_scale: float = 0.2 + config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True - @nn.compact + def setup(self): + self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype) + self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype) + self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype) + def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - # Embedding - embeddings = FlaxRobertaEmbeddings( - self.vocab_size, - self.hidden_size, - self.type_vocab_size, - self.max_length, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - name="embeddings", - dtype=self.dtype, - )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic) - - # N stacked encoding layers - encoder = FlaxRobertaEncoder( - self.num_encoder_layers, - self.num_heads, - self.head_size, - self.intermediate_size, - kernel_init_scale=self.kernel_init_scale, - dropout_rate=self.dropout_rate, - hidden_act=self.hidden_act, - name="encoder", - dtype=self.dtype, - )(embeddings, attention_mask, deterministic=deterministic) + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic) if not self.add_pooling_layer: - return encoder + return hidden_states - pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder) - return encoder, pooled + pooled = self.pooler(hidden_states) + return hidden_states, pooled diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 19e900aef4..0b517a5f43 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None): return attn_mask +@require_flax class FlaxModelTesterMixin: model_tester = None all_model_classes = () @@ -90,7 +91,7 @@ class FlaxModelTesterMixin: fx_outputs = fx_model(**inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) @@ -103,7 +104,6 @@ class FlaxModelTesterMixin: for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3) - @require_flax def test_from_pretrained_save_pretrained(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -121,7 +121,6 @@ class FlaxModelTesterMixin: for output_loaded, output in zip(outputs_loaded, outputs): self.assert_almost_equals(output_loaded, output, 5e-3) - @require_flax def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -144,7 +143,6 @@ class FlaxModelTesterMixin: for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - @require_flax def test_naming_convention(self): for model_class in self.all_model_classes: model_class_name = model_class.__name__