Fix cross-attention head mask for Torch encoder-decoder models (#10605)

* Fix cross-attention head mask for Torch BART models

* Fix head masking for cross-attention module for the following
models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart,
Pegasus

* Enable test_headmasking for M2M_100 model

* Fix cross_head_mask for FSMT, LED and T5

* This commit fixes `head_mask` for cross-attention modules
in the following models: FSMT, LED, T5

* It also contains some smaller changes in doc so that
it is perfectly clear that the shape of `cross_head_mask`
is the same as that of `decoder_head_mask`

* Update template

* Fix template for BartForCausalLM

* Fix cross_head_mask for Speech2Text models

* Fix cross_head_mask in templates

* Fix args order in BartForCausalLM template

* Fix doc in BART templates

* Make more explicit naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Fix doc

* make style quality

* Fix speech2text docstring
This commit is contained in:
Daniel Stancl
2021-04-23 18:58:06 +02:00
committed by GitHub
parent ca6b80cadb
commit e3ff165aa5
23 changed files with 587 additions and 389 deletions

View File

@@ -296,7 +296,7 @@ class BartEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -368,7 +368,7 @@ class BartDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -382,9 +382,9 @@ class BartDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -419,7 +419,7 @@ class BartDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -598,18 +598,25 @@ BART_INPUTS_DOCSTRING = r"""
If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
information on the default strategy. information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -710,11 +717,11 @@ class BartEncoder(BartPretrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -875,7 +882,7 @@ class BartDecoder(BartPretrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -912,18 +919,18 @@ class BartDecoder(BartPretrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -993,11 +1000,12 @@ class BartDecoder(BartPretrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -1031,7 +1039,7 @@ class BartDecoder(BartPretrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1042,7 +1050,9 @@ class BartDecoder(BartPretrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1123,6 +1133,7 @@ class BartModel(BartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1172,7 +1183,7 @@ class BartModel(BartPretrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1248,6 +1259,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1282,6 +1294,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1386,6 +1399,7 @@ class BartForSequenceClassification(BartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
@@ -1416,6 +1430,7 @@ class BartForSequenceClassification(BartPretrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1496,6 +1511,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
@@ -1527,6 +1543,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1633,7 +1650,7 @@ class BartForCausalLM(BartPretrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1666,18 +1683,17 @@ class BartForCausalLM(BartPretrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1734,7 +1750,7 @@ class BartForCausalLM(BartPretrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -298,7 +298,7 @@ class BlenderbotEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -371,7 +371,7 @@ class BlenderbotDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -385,9 +385,9 @@ class BlenderbotDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -423,7 +423,7 @@ class BlenderbotDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -554,18 +554,25 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -666,11 +673,11 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -834,7 +841,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -871,18 +878,19 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -951,11 +959,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -989,7 +998,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1000,7 +1009,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1090,6 +1101,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1147,7 +1159,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1241,6 +1253,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1275,6 +1288,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1395,7 +1409,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1428,18 +1442,17 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1496,7 +1509,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -296,7 +296,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -369,7 +369,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -383,9 +383,9 @@ class BlenderbotSmallDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -420,7 +420,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -555,18 +555,25 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -667,11 +674,11 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -834,7 +841,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -871,18 +878,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -953,10 +960,12 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
if head_mask is not None: # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
assert head_mask.size()[0] == ( for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
len(self.layers) if attn_mask is not None:
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -990,7 +999,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1001,7 +1010,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1077,6 +1088,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1134,7 +1146,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1216,6 +1228,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1250,6 +1263,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1370,7 +1384,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1403,18 +1417,17 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1471,7 +1484,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -248,17 +248,25 @@ FSMT_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -573,7 +581,7 @@ class DecoderLayer(nn.Module):
layer_state=None, layer_state=None,
causal_mask=None, causal_mask=None,
layer_head_mask=None, layer_head_mask=None,
encoder_layer_head_mask=None, cross_attn_layer_head_mask=None,
decoder_padding_mask=None, decoder_padding_mask=None,
output_attentions=False, output_attentions=False,
): ):
@@ -604,7 +612,7 @@ class DecoderLayer(nn.Module):
key=encoder_hidden_states, key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask, key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state layer_state=layer_state, # mutates layer state
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
@@ -666,7 +674,7 @@ class FSMTDecoder(nn.Module):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask, decoder_causal_mask,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
@@ -690,12 +698,11 @@ class FSMTDecoder(nn.Module):
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
Returns: Returns:
BaseModelOutputWithPast or tuple: BaseModelOutputWithPast or tuple:
@@ -732,10 +739,11 @@ class FSMTDecoder(nn.Module):
next_decoder_cache = [] next_decoder_cache = []
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -756,7 +764,7 @@ class FSMTDecoder(nn.Module):
layer_state=layer_state, layer_state=layer_state,
causal_mask=decoder_causal_mask, causal_mask=decoder_causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -1009,6 +1017,7 @@ class FSMTModel(PretrainedFSMTModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Tuple] = None, encoder_outputs: Optional[Tuple] = None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
@@ -1065,7 +1074,7 @@ class FSMTModel(PretrainedFSMTModel):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask=causal_mask, decoder_causal_mask=causal_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
@@ -1143,6 +1152,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
labels=None, labels=None,
@@ -1173,6 +1183,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,

View File

@@ -901,7 +901,7 @@ class LEDEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
""" """
residual = hidden_states residual = hidden_states
attn_outputs = self.self_attn( attn_outputs = self.self_attn(
@@ -968,7 +968,7 @@ class LEDDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -982,9 +982,9 @@ class LEDDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`): Whether the base model outputs attentions. output_attentions (:obj:`bool`): Whether the base model outputs attentions.
This requires the attentions tensor to be reshaped in this function. This requires the attentions tensor to be reshaped in this function.
@@ -1018,7 +1018,7 @@ class LEDDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -1199,17 +1199,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
last_hidden_state: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None
@@ -1221,8 +1210,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
@@ -1278,17 +1265,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -1301,8 +1277,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
@@ -1358,17 +1332,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -1381,8 +1344,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
@@ -1440,17 +1401,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -1464,8 +1414,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
LED_START_DOCSTRING = r""" LED_START_DOCSTRING = r"""
@@ -1547,17 +1495,24 @@ LED_INPUTS_DOCSTRING = r"""
- 0 for local attention (a sliding window attention), - 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -1730,7 +1685,7 @@ class LEDEncoder(LEDPreTrainedModel):
- 0 for local attention (a sliding window attention), - 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
@@ -1914,7 +1869,7 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -1961,18 +1916,17 @@ class LEDDecoder(LEDPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -2052,11 +2006,12 @@ class LEDDecoder(LEDPreTrainedModel):
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -2090,7 +2045,7 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -2100,7 +2055,9 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -2180,6 +2137,7 @@ class LEDModel(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
past_key_values=None, past_key_values=None,
@@ -2224,7 +2182,7 @@ class LEDModel(LEDPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -2306,6 +2264,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
past_key_values=None, past_key_values=None,
@@ -2358,6 +2317,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2463,6 +2423,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2495,6 +2456,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2571,6 +2533,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
start_positions=None, start_positions=None,
@@ -2604,6 +2567,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,

View File

@@ -367,7 +367,7 @@ class M2M100EncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -440,7 +440,7 @@ class M2M100DecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -454,9 +454,9 @@ class M2M100DecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -492,7 +492,7 @@ class M2M100DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -603,6 +603,24 @@ M2M_100_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -704,6 +722,12 @@ class M2M100Encoder(M2M100PreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -841,7 +865,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -878,6 +902,19 @@ class M2M100Decoder(M2M100PreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
@@ -955,11 +992,12 @@ class M2M100Decoder(M2M100PreTrainedModel):
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -993,7 +1031,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1004,7 +1042,9 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1085,6 +1125,7 @@ class M2M100Model(M2M100PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1126,7 +1167,7 @@ class M2M100Model(M2M100PreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1201,6 +1242,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1249,6 +1291,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1281,7 +1324,14 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
): ):
# cut decoder_input_ids if past is used # cut decoder_input_ids if past is used
if past is not None: if past is not None:
@@ -1293,6 +1343,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }

View File

@@ -313,7 +313,7 @@ class MarianEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -386,7 +386,7 @@ class MarianDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -400,9 +400,9 @@ class MarianDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -437,7 +437,7 @@ class MarianDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -567,18 +567,25 @@ MARIAN_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -678,11 +685,11 @@ class MarianEncoder(MarianPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -842,7 +849,7 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -879,18 +886,18 @@ class MarianDecoder(MarianPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -959,11 +966,12 @@ class MarianDecoder(MarianPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -997,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1008,7 +1016,9 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1084,6 +1094,7 @@ class MarianModel(MarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1142,7 +1153,7 @@ class MarianModel(MarianPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1229,6 +1240,7 @@ class MarianMTModel(MarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1264,6 +1276,7 @@ class MarianMTModel(MarianPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1391,7 +1404,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1424,18 +1437,17 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1492,7 +1504,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -303,7 +303,7 @@ class MBartEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -375,7 +375,7 @@ class MBartDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -389,9 +389,9 @@ class MBartDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -427,7 +427,7 @@ class MBartDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -595,18 +595,25 @@ MBART_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -708,11 +715,11 @@ class MBartEncoder(MBartPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -877,7 +884,7 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -914,18 +921,18 @@ class MBartDecoder(MBartPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -995,11 +1002,12 @@ class MBartDecoder(MBartPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -1033,7 +1041,7 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1044,7 +1052,9 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1127,6 +1137,7 @@ class MBartModel(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1173,7 +1184,7 @@ class MBartModel(MBartPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1254,6 +1265,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1287,6 +1299,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1384,6 +1397,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
@@ -1414,6 +1428,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1495,6 +1510,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
@@ -1526,6 +1542,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1634,7 +1651,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1667,18 +1684,17 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1735,7 +1751,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -313,7 +313,7 @@ class PegasusEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
@@ -386,7 +386,7 @@ class PegasusDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -400,9 +400,9 @@ class PegasusDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -438,7 +438,7 @@ class PegasusDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -566,18 +566,25 @@ PEGASUS_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -679,11 +686,11 @@ class PegasusEncoder(PegasusPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -848,7 +855,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -885,18 +892,18 @@ class PegasusDecoder(PegasusPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -965,11 +972,12 @@ class PegasusDecoder(PegasusPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -1003,7 +1011,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1014,7 +1022,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1092,6 +1102,7 @@ class PegasusModel(PegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1150,7 +1161,7 @@ class PegasusModel(PegasusPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1232,6 +1243,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1267,6 +1279,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1390,7 +1403,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -1423,18 +1436,17 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1491,7 +1503,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -451,7 +451,7 @@ class Speech2TextDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -465,9 +465,9 @@ class Speech2TextDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
:obj:`(config.encoder_attention_heads,)`. :obj:`(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size :obj:`(config.encoder_attention_heads,)`. size :obj:`(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -503,7 +503,7 @@ class Speech2TextDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -623,19 +623,29 @@ SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`). :obj:`past_key_values`).
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
If you want to change padding behavior, you should read
:func:`modeling_speech_to_text._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the
paper <https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -728,11 +738,11 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -884,7 +894,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -921,18 +931,18 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
cross-attention on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1001,12 +1011,12 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -1039,7 +1049,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -1050,7 +1060,9 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1127,6 +1139,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
@@ -1166,7 +1179,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -1240,6 +1253,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
@@ -1296,6 +1310,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -607,7 +607,7 @@ class T5Block(nn.Module):
encoder_attention_mask=None, encoder_attention_mask=None,
encoder_decoder_position_bias=None, encoder_decoder_position_bias=None,
layer_head_mask=None, layer_head_mask=None,
encoder_layer_head_mask=None, cross_attn_layer_head_mask=None,
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
@@ -661,7 +661,7 @@ class T5Block(nn.Module):
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
query_length=query_length, query_length=query_length,
use_cache=use_cache, use_cache=use_cache,
@@ -846,7 +846,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=None, encoder_attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -913,7 +913,7 @@ class T5Stack(T5PreTrainedModel):
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
@@ -925,7 +925,7 @@ class T5Stack(T5PreTrainedModel):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
encoder_layer_head_mask = encoder_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
@@ -942,8 +942,8 @@ class T5Stack(T5PreTrainedModel):
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if layer_head_mask is not None: if layer_head_mask is not None:
layer_head_mask = layer_head_mask.to(hidden_states.device) layer_head_mask = layer_head_mask.to(hidden_states.device)
if encoder_layer_head_mask is not None: if cross_attn_layer_head_mask is not None:
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@@ -955,7 +955,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
encoder_layer_head_mask=encoder_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
@@ -1082,12 +1082,19 @@ T5_INPUTS_DOCSTRING = r"""
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. in the decoder Mask values selected in ``[0, Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
1]``: 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -1263,6 +1270,7 @@ class T5Model(T5PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1338,7 +1346,7 @@ class T5Model(T5PreTrainedModel):
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
@@ -1451,6 +1459,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1551,7 +1560,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,

View File

@@ -1041,10 +1041,11 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None,
inputs_embeds=None, inputs_embeds=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
labels=None, labels=None,
use_cache=None, use_cache=None,
@@ -1876,7 +1877,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
@@ -1890,9 +1891,9 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -1927,7 +1928,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -2070,18 +2071,24 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
information on the default strategy. information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -2211,10 +2218,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -2377,7 +2385,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
@@ -2414,18 +2422,17 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -2493,12 +2500,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
@@ -2529,7 +2536,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
@@ -2540,7 +2547,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -2621,6 +2628,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2662,7 +2670,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
@@ -2743,6 +2751,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2791,6 +2800,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
@@ -3124,7 +3134,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
@@ -3157,18 +3167,17 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -3225,7 +3234,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -55,6 +55,7 @@ def prepare_bart_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -64,6 +65,8 @@ def prepare_bart_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -71,6 +74,7 @@ def prepare_bart_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -54,6 +55,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -61,6 +64,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -59,6 +60,8 @@ def prepare_blenderbot_small_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -225,8 +225,8 @@ class ModelTesterMixin:
"decoder_attention_mask", "decoder_attention_mask",
] ]
expected_arg_names.extend( expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"] ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"] else ["encoder_outputs"]
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@@ -492,6 +492,8 @@ class ModelTesterMixin:
arg_names = [*signature.parameters.keys()] arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
inputs["decoder_head_mask"] = head_mask inputs["decoder_head_mask"] = head_mask
if "cross_attn_head_mask" in arg_names:
inputs["cross_attn_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True) outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation # Test that we can get a gradient back for importance score computation
@@ -523,6 +525,7 @@ class ModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions) check_attentions_validity(outputs.decoder_attentions)
check_attentions_validity(outputs.cross_attentions)
else: else:
check_attentions_validity(outputs.attentions) check_attentions_validity(outputs.attentions)
@@ -1093,7 +1096,7 @@ class ModelTesterMixin:
# some params shouldn't be scattered by nn.DataParallel # some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present. # so just remove them if they are present.
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"] blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
for k in blacklist_non_batched_params: for k in blacklist_non_batched_params:
inputs_dict.pop(k, None) inputs_dict.pop(k, None)

View File

@@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict(
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -120,6 +121,8 @@ def prepare_fsmt_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,

View File

@@ -52,6 +52,7 @@ def prepare_led_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@ def prepare_led_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@ def prepare_led_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
@@ -142,9 +154,10 @@ class M2M100ModelTester:
model = M2M100Model(config=config).get_decoder().to(torch_device).eval() model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"] attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
@@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False
test_missing_keys = False test_missing_keys = False
def setUp(self): def setUp(self):

View File

@@ -60,6 +60,7 @@ def prepare_marian_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -69,6 +70,8 @@ def prepare_marian_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -76,6 +79,7 @@ def prepare_marian_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -52,6 +52,7 @@ def prepare_mbart_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@ def prepare_mbart_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@ def prepare_mbart_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -42,6 +42,7 @@ def prepare_pegasus_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
@@ -51,6 +52,8 @@ def prepare_pegasus_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -58,6 +61,7 @@ def prepare_pegasus_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -55,17 +55,29 @@ def prepare_speech_to_text_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_features.ne(0) attention_mask = input_features.ne(0)
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
# "input_ids": input_features, # "input_ids": input_features,
"input_features": input_features, "input_features": input_features,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
@@ -247,7 +259,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False
test_missing_keys = False test_missing_keys = False
test_torchscript = True test_torchscript = True
@@ -316,8 +327,8 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
"decoder_attention_mask", "decoder_attention_mask",
] ]
expected_arg_names.extend( expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"] ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"] else ["encoder_outputs"]
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)