Implement head_mask for Flax BERT and other models copied from BERT (#14620)

* Implement head_mask for Flax BERT and other models copied from BERT

* Remove `from jax._src.nn.functions import sigmoid`

Remove `from jax._src.nn.functions import sigmoid` unintentionally added by IDE

* Remove no more valid copy statement

* Apply patil-suraj's suggestions from code review

* Apply suggestions from the code review

* Update Flax template

* Fix a typo

* Also update template for CausalLM modules
This commit is contained in:
Daniel Stancl
2021-12-17 17:06:59 +01:00
committed by GitHub
parent 95119ad7b0
commit ff066119ca
9 changed files with 484 additions and 48 deletions

View File

@@ -116,6 +116,12 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
head_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
@@ -192,7 +198,14 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
deterministic=True,
output_attentions: bool = False
):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
@@ -233,6 +246,10 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
precision=None,
)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
@@ -270,12 +287,23 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
self.self = Flax{{cookiecutter.camelcase_modelname}}SelfAttention(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}SelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
deterministic=True,
output_attentions: bool = False,
):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_outputs = self.self(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
@@ -338,9 +366,20 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
deterministic: bool = True,
output_attentions: bool = False,
):
attention_outputs = self.attention(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
@@ -368,6 +407,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
self,
hidden_states,
attention_mask,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -376,12 +416,24 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
{head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
@@ -414,6 +466,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
self,
hidden_states,
attention_mask,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -422,6 +475,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
return self.layer(
hidden_states,
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -547,13 +601,14 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
"params"
]
return self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)["params"]
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
@@ -562,6 +617,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
@@ -585,6 +641,9 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if head_mask is None:
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
@@ -596,6 +655,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(head_mask, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
@@ -620,6 +680,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -631,6 +692,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
outputs = self.encoder(
hidden_states,
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -674,6 +736,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -685,6 +748,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -733,6 +797,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -744,6 +809,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -797,6 +863,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -808,6 +875,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -863,6 +931,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module)
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -880,6 +949,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module)
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -936,6 +1006,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Mo
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -947,6 +1018,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Mo
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -997,6 +1069,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -1008,6 +1081,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,