Implement head_mask for Flax BERT and other models copied from BERT (#14620)
* Implement head_mask for Flax BERT and other models copied from BERT * Remove `from jax._src.nn.functions import sigmoid` Remove `from jax._src.nn.functions import sigmoid` unintentionally added by IDE * Remove no more valid copy statement * Apply patil-suraj's suggestions from code review * Apply suggestions from the code review * Update Flax template * Fix a typo * Also update template for CausalLM modules
This commit is contained in:
@@ -161,6 +161,12 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
@@ -234,7 +240,14 @@ class FlaxBertSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -275,6 +288,10 @@ class FlaxBertSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
||||||
|
|
||||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
@@ -310,12 +327,23 @@ class FlaxBertAttention(nn.Module):
|
|||||||
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_outputs = self.self(
|
attn_outputs = self.self(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
@@ -375,9 +403,20 @@ class FlaxBertLayer(nn.Module):
|
|||||||
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
@@ -404,6 +443,7 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -412,12 +452,24 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.shape[0] != (len(self.layers)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
|
||||||
|
{head_mask.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -449,6 +501,7 @@ class FlaxBertEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -457,6 +510,7 @@ class FlaxBertEncoder(nn.Module):
|
|||||||
return self.layer(
|
return self.layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -577,13 +631,14 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
return self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -592,6 +647,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: jax.random.PRNGKey = None,
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
@@ -615,6 +671,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
@@ -626,6 +685,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -650,6 +710,7 @@ class FlaxBertModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids: Optional[np.ndarray] = None,
|
token_type_ids: Optional[np.ndarray] = None,
|
||||||
position_ids: Optional[np.ndarray] = None,
|
position_ids: Optional[np.ndarray] = None,
|
||||||
|
head_mask: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -669,6 +730,7 @@ class FlaxBertModule(nn.Module):
|
|||||||
outputs = self.encoder(
|
outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -718,6 +780,7 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -730,6 +793,7 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -810,6 +874,7 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -821,6 +886,7 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -870,6 +936,7 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -883,6 +950,7 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -962,6 +1030,7 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -973,6 +1042,7 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1028,6 +1098,7 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1045,6 +1116,7 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1106,6 +1178,7 @@ class FlaxBertForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1117,6 +1190,7 @@ class FlaxBertForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1167,6 +1241,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1178,6 +1253,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -178,6 +178,12 @@ BIG_BIRD_INPUTS_DOCSTRING = r"""
|
|||||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
@@ -256,7 +262,14 @@ class FlaxBigBirdSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -297,6 +310,10 @@ class FlaxBigBirdSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
||||||
|
|
||||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
@@ -1113,14 +1130,32 @@ class FlaxBigBirdAttention(nn.Module):
|
|||||||
|
|
||||||
self.output = FlaxBigBirdSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBigBirdSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention.__call__ with Bert->BigBird
|
def __call__(
|
||||||
def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions: bool = False):
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_outputs = self.self(
|
if self.config.attention_type == "original_full":
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
attn_outputs = self.self(
|
||||||
)
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
|
|
||||||
@@ -1183,9 +1218,20 @@ class FlaxBigBirdLayer(nn.Module):
|
|||||||
self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
@@ -1214,6 +1260,7 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1222,12 +1269,24 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.shape[0] != (len(self.layers)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
|
||||||
|
{head_mask.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -1260,6 +1319,7 @@ class FlaxBigBirdEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1268,6 +1328,7 @@ class FlaxBigBirdEncoder(nn.Module):
|
|||||||
return self.layer(
|
return self.layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1374,13 +1435,14 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
return self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -1389,6 +1451,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: jax.random.PRNGKey = None,
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
@@ -1412,6 +1475,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
@@ -1423,6 +1489,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -1451,6 +1518,7 @@ class FlaxBigBirdModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1462,6 +1530,7 @@ class FlaxBigBirdModule(nn.Module):
|
|||||||
outputs = self.encoder(
|
outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1514,6 +1583,7 @@ class FlaxBigBirdForPreTrainingModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1526,6 +1596,7 @@ class FlaxBigBirdForPreTrainingModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1608,6 +1679,7 @@ class FlaxBigBirdForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1619,6 +1691,7 @@ class FlaxBigBirdForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1695,6 +1768,7 @@ class FlaxBigBirdForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1706,6 +1780,7 @@ class FlaxBigBirdForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1762,6 +1837,7 @@ class FlaxBigBirdForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1779,6 +1855,7 @@ class FlaxBigBirdForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1859,6 +1936,7 @@ class FlaxBigBirdForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1870,6 +1948,7 @@ class FlaxBigBirdForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1945,6 +2024,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
logits_mask=None,
|
logits_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
@@ -1958,6 +2038,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -2005,6 +2086,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
question_lengths=None,
|
question_lengths=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: jax.random.PRNGKey = None,
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
@@ -2025,6 +2107,9 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
if question_lengths is None and input_ids is not None:
|
if question_lengths is None and input_ids is not None:
|
||||||
# assuming input_ids format: <cls> <question> <sep> context <sep>
|
# assuming input_ids format: <cls> <question> <sep> context <sep>
|
||||||
question_lengths = jnp.argmax((input_ids == self.config.sep_token_id).astype("i4"), axis=-1) + 1
|
question_lengths = jnp.argmax((input_ids == self.config.sep_token_id).astype("i4"), axis=-1) + 1
|
||||||
@@ -2056,6 +2141,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
logits_mask,
|
logits_mask,
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
|||||||
@@ -131,6 +131,12 @@ ELECTRA_INPUTS_DOCSTRING = r"""
|
|||||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
@@ -206,7 +212,14 @@ class FlaxElectraSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -247,6 +260,10 @@ class FlaxElectraSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
||||||
|
|
||||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
@@ -284,12 +301,23 @@ class FlaxElectraAttention(nn.Module):
|
|||||||
self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype)
|
self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_outputs = self.self(
|
attn_outputs = self.self(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
@@ -352,9 +380,20 @@ class FlaxElectraLayer(nn.Module):
|
|||||||
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
|
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
|
self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
@@ -382,6 +421,7 @@ class FlaxElectraLayerCollection(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -390,12 +430,24 @@ class FlaxElectraLayerCollection(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.shape[0] != (len(self.layers)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
|
||||||
|
{head_mask.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -428,6 +480,7 @@ class FlaxElectraEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -436,6 +489,7 @@ class FlaxElectraEncoder(nn.Module):
|
|||||||
return self.layer(
|
return self.layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -502,13 +556,14 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
return self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -517,6 +572,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
@@ -541,6 +597,9 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
@@ -552,6 +611,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -576,6 +636,7 @@ class FlaxElectraModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -590,6 +651,7 @@ class FlaxElectraModule(nn.Module):
|
|||||||
return self.encoder(
|
return self.encoder(
|
||||||
embeddings,
|
embeddings,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -650,6 +712,7 @@ class FlaxElectraForMaskedLMModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -660,6 +723,7 @@ class FlaxElectraForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -708,6 +772,7 @@ class FlaxElectraForPreTrainingModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -719,6 +784,7 @@ class FlaxElectraForPreTrainingModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -795,6 +861,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -806,6 +873,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -935,6 +1003,7 @@ class FlaxElectraForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -952,6 +1021,7 @@ class FlaxElectraForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1011,6 +1081,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1022,6 +1093,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1104,6 +1176,7 @@ class FlaxElectraForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1115,6 +1188,7 @@ class FlaxElectraForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -122,6 +122,12 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
|||||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -196,7 +202,14 @@ class FlaxRobertaSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -237,6 +250,10 @@ class FlaxRobertaSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
||||||
|
|
||||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
@@ -274,12 +291,23 @@ class FlaxRobertaAttention(nn.Module):
|
|||||||
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_outputs = self.self(
|
attn_outputs = self.self(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
@@ -342,9 +370,20 @@ class FlaxRobertaLayer(nn.Module):
|
|||||||
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
@@ -372,6 +411,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -380,12 +420,24 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.shape[0] != (len(self.layers)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
|
||||||
|
{head_mask.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -418,6 +470,7 @@ class FlaxRobertaEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -426,6 +479,7 @@ class FlaxRobertaEncoder(nn.Module):
|
|||||||
return self.layer(
|
return self.layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -546,13 +600,14 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
return self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -561,6 +616,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
@@ -584,6 +640,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
@@ -595,6 +654,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -620,6 +680,7 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids: Optional[np.ndarray] = None,
|
token_type_ids: Optional[np.ndarray] = None,
|
||||||
position_ids: Optional[np.ndarray] = None,
|
position_ids: Optional[np.ndarray] = None,
|
||||||
|
head_mask: Optional[np.ndarray] = None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -639,6 +700,7 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
outputs = self.encoder(
|
outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -688,6 +750,7 @@ class FlaxRobertaForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -699,6 +762,7 @@ class FlaxRobertaForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -753,6 +817,7 @@ class FlaxRobertaForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -764,6 +829,7 @@ class FlaxRobertaForSequenceClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -819,6 +885,7 @@ class FlaxRobertaForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -836,6 +903,7 @@ class FlaxRobertaForMultipleChoiceModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -902,6 +970,7 @@ class FlaxRobertaForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -913,6 +982,7 @@ class FlaxRobertaForTokenClassificationModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -968,6 +1038,7 @@ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -979,6 +1050,7 @@ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -116,6 +116,12 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
|||||||
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
|
||||||
@@ -192,7 +198,14 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False
|
||||||
|
):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -233,6 +246,10 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
|
||||||
|
|
||||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
@@ -270,12 +287,23 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
self.self = Flax{{cookiecutter.camelcase_modelname}}SelfAttention(self.config, dtype=self.dtype)
|
self.self = Flax{{cookiecutter.camelcase_modelname}}SelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = Flax{{cookiecutter.camelcase_modelname}}SelfOutput(self.config, dtype=self.dtype)
|
self.output = Flax{{cookiecutter.camelcase_modelname}}SelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic=True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_outputs = self.self(
|
attn_outputs = self.self(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
@@ -338,9 +366,20 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
|||||||
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
|
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
|
||||||
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
|
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
@@ -368,6 +407,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -376,12 +416,24 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# Check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.shape[0] != (len(self.layers)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
|
||||||
|
{head_mask.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -414,6 +466,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -422,6 +475,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
return self.layer(
|
return self.layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -547,13 +601,14 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
return self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -562,6 +617,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: jax.random.PRNGKey = None,
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
@@ -585,6 +641,9 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
@@ -596,6 +655,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(head_mask, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -620,6 +680,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -631,6 +692,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
|||||||
outputs = self.encoder(
|
outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -674,6 +736,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -685,6 +748,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -733,6 +797,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -744,6 +809,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -797,6 +863,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -808,6 +875,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -863,6 +931,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module)
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -880,6 +949,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module)
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -936,6 +1006,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Mo
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -947,6 +1018,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Mo
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -997,6 +1069,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
@@ -1008,6 +1081,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
head_mask,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -118,6 +118,8 @@ class FlaxBertModelTester(unittest.TestCase):
|
|||||||
@require_flax
|
@require_flax
|
||||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
test_head_masking = True
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
FlaxBertModel,
|
FlaxBertModel,
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class FlaxModelTesterMixin:
|
|||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
test_mismatched_shapes = True
|
test_mismatched_shapes = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
test_head_masking = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -777,6 +778,53 @@ class FlaxModelTesterMixin:
|
|||||||
for name, type_ in types.items():
|
for name, type_ in types.items():
|
||||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||||
|
|
||||||
|
def test_headmasking(self):
|
||||||
|
if not self.test_head_masking:
|
||||||
|
return
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
def _prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
|
||||||
|
if i == 0:
|
||||||
|
return np.concatenate([np.zeros(1, dtype=jnp.int32), np.ones(attention_heads - 1, dtype=jnp.int32)])
|
||||||
|
if i == num_hidden_layers - 1:
|
||||||
|
return np.concatenate([np.zeros(attention_heads - 1, dtype=jnp.int32), np.ones(1, dtype=jnp.int32)])
|
||||||
|
return np.ones(attention_heads, dtype=jnp.int32)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = False
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||||
|
# Prepare head mask
|
||||||
|
inputs["head_mask"] = np.stack(
|
||||||
|
[
|
||||||
|
_prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
def _check_attentions_validity(attentions):
|
||||||
|
# Remove NaN
|
||||||
|
for t in attentions:
|
||||||
|
# Check we don't have more than 25% nans (arbitrary)
|
||||||
|
self.assertLess(np.isnan(t).sum(), t.size / 4)
|
||||||
|
attentions = [np.where(np.isnan(t), 0.0, t) for t in attentions]
|
||||||
|
|
||||||
|
self.assertAlmostEqual(attentions[0][..., 0, :, :].sum(), 0.0)
|
||||||
|
self.assertNotEqual(attentions[0][..., -1, :, :].sum(), 0.0)
|
||||||
|
if len(attentions) > 2: # encoder-decodere models have only 2 layers in each modules
|
||||||
|
self.assertNotEqual(attentions[1][..., 0, :, :].sum(), 0.0)
|
||||||
|
self.assertAlmostEqual(attentions[-1][..., -2, :, :].sum(), 0.0)
|
||||||
|
self.assertNotEqual(attentions[-1][..., -1, :, :].sum(), 0.0)
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
raise NotImplementedError("The test has not been implemented for encoder-decoder models yet.")
|
||||||
|
else:
|
||||||
|
_check_attentions_validity(outputs.attentions)
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
@@ -105,6 +105,8 @@ class FlaxElectraModelTester(unittest.TestCase):
|
|||||||
@require_flax
|
@require_flax
|
||||||
class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
test_head_masking = True
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
FlaxElectraModel,
|
FlaxElectraModel,
|
||||||
|
|||||||
@@ -116,6 +116,8 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
|||||||
@require_flax
|
@require_flax
|
||||||
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
test_head_masking = True
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
FlaxRobertaModel,
|
FlaxRobertaModel,
|
||||||
|
|||||||
Reference in New Issue
Block a user