Fix cross-attention head mask for Torch encoder-decoder models (#10605)
* Fix cross-attention head mask for Torch BART models * Fix head masking for cross-attention module for the following models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart, Pegasus * Enable test_headmasking for M2M_100 model * Fix cross_head_mask for FSMT, LED and T5 * This commit fixes `head_mask` for cross-attention modules in the following models: FSMT, LED, T5 * It also contains some smaller changes in doc so that it is be perfectly clear the shape of `cross_head_mask` is the same as of `decoder_head_mask` * Update template * Fix template for BartForCausalLM * Fix cross_head_mask for Speech2Text models * Fix cross_head_mask in templates * Fix args order in BartForCausalLM template * Fix doc in BART templates * Make more explicit naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Fix doc * make style quality * Fix speech2text docstring
This commit is contained in:
@@ -1041,10 +1041,11 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
@@ -1876,7 +1877,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@@ -1890,9 +1891,9 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@@ -1927,7 +1928,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -2070,18 +2071,24 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
|
||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||
information on the default strategy.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@@ -2211,10 +2218,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@@ -2377,7 +2385,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@@ -2414,18 +2422,17 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@@ -2493,12 +2500,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@@ -2529,7 +2536,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@@ -2540,7 +2547,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -2621,6 +2628,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -2662,7 +2670,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@@ -2743,6 +2751,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -2791,6 +2800,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@@ -3124,7 +3134,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@@ -3157,18 +3167,17 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@@ -3225,7 +3234,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
||||
Reference in New Issue
Block a user