Add head_mask/decoder_head_mask for BART (#9569)
* Add head_mask/decoder_head_mask for BART
This branch implement head_mask and decoder_head_mask
for BART-based models. Full list below:
- BART
- MBart
- Blenderbot
- BlenderbotSmall
- Marian
- Pegasus
Everything is accompanied with updated testing.
* Fix test_headmasking for BART models
* Fix text_headmasking for BART-like models
which has only 2 layers in each modules.
The condition
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
is, therefore, invalid for encoder-decoder models considering
the `head_mask`
```
head_mask = torch.ones(
self.model_tester.num_hidden_layers,
self.model_tester.num_attention_heads,
device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
```
specified in the `test_headmasking` test/function.
* Adjust test_modeling_common.py to reflect T5 input args
* Update tests/test_modeling_common.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make style
* make fix-copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -160,6 +160,7 @@ class BartAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -227,6 +228,13 @@ class BartAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -275,19 +283,30 @@ class BartEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -346,6 +365,8 @@ class BartDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -358,6 +379,10 @@ class BartDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -373,6 +398,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -391,6 +417,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -568,6 +595,18 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
|
If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
|
||||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -645,6 +684,7 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -668,6 +708,12 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -714,7 +760,13 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -734,9 +786,15 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -791,6 +849,8 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -827,6 +887,19 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -904,6 +977,12 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -933,6 +1012,8 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -942,6 +1023,8 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1018,6 +1101,8 @@ class BartModel(BartPretrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1046,6 +1131,7 @@ class BartModel(BartPretrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1065,6 +1151,8 @@ class BartModel(BartPretrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1143,6 +1231,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1175,6 +1265,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1207,7 +1299,14 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1219,6 +1318,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1278,6 +1378,8 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -1306,6 +1408,8 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1379,6 +1483,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
start_positions=None,
|
start_positions=None,
|
||||||
end_positions=None,
|
end_positions=None,
|
||||||
@@ -1408,6 +1514,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ class BlenderbotAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -227,6 +228,13 @@ class BlenderbotAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -276,12 +284,20 @@ class BlenderbotEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -289,7 +305,10 @@ class BlenderbotEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -348,6 +367,8 @@ class BlenderbotDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -360,6 +381,10 @@ class BlenderbotDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -376,6 +401,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -394,6 +420,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -528,6 +555,18 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_blenderbot._prepare_decoder_inputs`
|
If you want to change padding behavior, you should read :func:`modeling_blenderbot._prepare_decoder_inputs`
|
||||||
and modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
and modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -605,6 +644,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -628,6 +668,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -673,7 +719,13 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -693,9 +745,15 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -751,8 +809,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -778,6 +838,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
of the decoder.
|
of the decoder.
|
||||||
@@ -789,6 +855,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -866,6 +939,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -895,6 +974,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -902,8 +983,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -989,6 +1072,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1025,6 +1110,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1044,6 +1130,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1135,6 +1223,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1167,6 +1257,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1199,7 +1291,14 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1211,6 +1310,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ class BlenderbotSmallAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -225,6 +226,13 @@ class BlenderbotSmallAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -274,19 +282,30 @@ class BlenderbotSmallEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -346,6 +365,8 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -358,6 +379,10 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -373,6 +398,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -391,6 +417,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -529,6 +556,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read
|
If you want to change padding behavior, you should read
|
||||||
:func:`modeling_blenderbot_small._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the
|
:func:`modeling_blenderbot_small._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the
|
||||||
paper <https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
|
paper <https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -606,6 +645,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -629,6 +669,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -675,7 +721,13 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -695,9 +747,15 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -753,6 +811,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -789,6 +849,19 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -868,6 +941,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -897,6 +974,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -906,6 +985,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -977,6 +1058,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1013,6 +1096,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1032,6 +1116,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1111,6 +1197,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1143,6 +1231,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1175,7 +1265,14 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1187,6 +1284,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ class MarianAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -241,6 +242,13 @@ class MarianAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -290,19 +298,30 @@ class MarianEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -362,6 +381,8 @@ class MarianDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -374,6 +395,10 @@ class MarianDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -389,6 +414,7 @@ class MarianDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -407,6 +433,7 @@ class MarianDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -539,6 +566,18 @@ MARIAN_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_marian._prepare_decoder_inputs` and
|
If you want to change padding behavior, you should read :func:`modeling_marian._prepare_decoder_inputs` and
|
||||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -614,6 +653,7 @@ class MarianEncoder(MarianPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -637,6 +677,12 @@ class MarianEncoder(MarianPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -682,7 +728,13 @@ class MarianEncoder(MarianPreTrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -702,9 +754,15 @@ class MarianEncoder(MarianPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -757,6 +815,8 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -793,6 +853,19 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -869,6 +942,12 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -898,6 +977,8 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -907,6 +988,8 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -978,6 +1061,8 @@ class MarianModel(MarianPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1015,6 +1100,7 @@ class MarianModel(MarianPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1034,6 +1120,8 @@ class MarianModel(MarianPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1118,6 +1206,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1151,6 +1241,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1183,7 +1275,14 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1195,6 +1294,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ class MBartAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -234,6 +235,13 @@ class MBartAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -282,12 +290,20 @@ class MBartEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -295,7 +311,10 @@ class MBartEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -353,6 +372,8 @@ class MBartDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -365,6 +386,10 @@ class MBartDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -381,6 +406,7 @@ class MBartDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -399,6 +425,7 @@ class MBartDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -573,6 +600,18 @@ MBART_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_mbart._prepare_decoder_inputs` and
|
If you want to change padding behavior, you should read :func:`modeling_mbart._prepare_decoder_inputs` and
|
||||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -651,6 +690,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -674,6 +714,12 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -720,7 +766,13 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -740,9 +792,15 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -800,6 +858,8 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -836,6 +896,19 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -913,6 +986,12 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -942,6 +1021,8 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -951,6 +1032,8 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1029,6 +1112,8 @@ class MBartModel(MBartPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1054,6 +1139,7 @@ class MBartModel(MBartPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1073,6 +1159,8 @@ class MBartModel(MBartPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1151,6 +1239,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1182,6 +1272,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1284,6 +1376,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -1312,6 +1406,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1386,6 +1482,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
start_positions=None,
|
start_positions=None,
|
||||||
end_positions=None,
|
end_positions=None,
|
||||||
@@ -1415,6 +1513,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ class PegasusAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -242,6 +243,13 @@ class PegasusAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit akward, but it's required to
|
# this operation is a bit akward, but it's required to
|
||||||
# make sure that attn_weights keeps its gradient.
|
# make sure that attn_weights keeps its gradient.
|
||||||
@@ -291,12 +299,20 @@ class PegasusEncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -304,7 +320,10 @@ class PegasusEncoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
hidden_states, attn_weights, _ = self.self_attn(
|
hidden_states, attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -363,6 +382,8 @@ class PegasusDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -375,6 +396,10 @@ class PegasusDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||||
@@ -391,6 +416,7 @@ class PegasusDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -409,6 +435,7 @@ class PegasusDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -540,6 +567,18 @@ PEGASUS_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`modeling_pegasus._prepare_decoder_inputs`
|
If you want to change padding behavior, you should read :func:`modeling_pegasus._prepare_decoder_inputs`
|
||||||
and modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
and modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||||
information on the default strategy.
|
information on the default strategy.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -617,6 +656,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -640,6 +680,12 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -686,7 +732,13 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -706,9 +758,15 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -765,6 +823,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -801,6 +861,19 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -877,6 +950,12 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -906,6 +985,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -915,6 +996,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -988,6 +1071,8 @@ class PegasusModel(PegasusPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1025,6 +1110,7 @@ class PegasusModel(PegasusPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1044,6 +1130,8 @@ class PegasusModel(PegasusPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1123,6 +1211,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1156,6 +1246,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -1188,7 +1280,14 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -1200,6 +1299,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,16 +53,24 @@ def prepare_bart_inputs_dict(
|
|||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -142,9 +150,10 @@ class BartModelTester:
|
|||||||
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -393,7 +402,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -40,16 +40,24 @@ def prepare_blenderbot_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -129,9 +137,10 @@ class BlenderbotModelTester:
|
|||||||
model = BlenderbotModel(config=config).get_decoder().to(torch_device).eval()
|
model = BlenderbotModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -197,7 +206,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -48,16 +48,24 @@ def prepare_blenderbot_small_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -137,9 +145,10 @@ class BlenderbotSmallModelTester:
|
|||||||
model = BlenderbotSmallModel(config=config).get_decoder().to(torch_device).eval()
|
model = BlenderbotSmallModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -205,7 +214,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
|
|||||||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -204,9 +204,13 @@ class ModelTesterMixin:
|
|||||||
"attention_mask",
|
"attention_mask",
|
||||||
"decoder_input_ids",
|
"decoder_input_ids",
|
||||||
"decoder_attention_mask",
|
"decoder_attention_mask",
|
||||||
"encoder_outputs",
|
|
||||||
]
|
]
|
||||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
expected_arg_names.extend(
|
||||||
|
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||||
|
if "head_mask" and "decoder_head_mask" in arg_names
|
||||||
|
else ["encoder_outputs"]
|
||||||
|
)
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
else:
|
else:
|
||||||
expected_arg_names = ["input_ids"]
|
expected_arg_names = ["input_ids"]
|
||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
@@ -395,7 +399,6 @@ class ModelTesterMixin:
|
|||||||
attention_mask = inputs["attention_mask"]
|
attention_mask = inputs["attention_mask"]
|
||||||
decoder_input_ids = inputs["decoder_input_ids"]
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
|
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
)
|
)
|
||||||
@@ -465,6 +468,11 @@ class ModelTesterMixin:
|
|||||||
head_mask.requires_grad_(requires_grad=True)
|
head_mask.requires_grad_(requires_grad=True)
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||||
inputs["head_mask"] = head_mask
|
inputs["head_mask"] = head_mask
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||||
|
inputs["decoder_head_mask"] = head_mask
|
||||||
|
|
||||||
outputs = model(**inputs, return_dict=True)
|
outputs = model(**inputs, return_dict=True)
|
||||||
|
|
||||||
@@ -474,8 +482,10 @@ class ModelTesterMixin:
|
|||||||
output.backward()
|
output.backward()
|
||||||
multihead_outputs = head_mask.grad
|
multihead_outputs = head_mask.grad
|
||||||
|
|
||||||
attentions = outputs[-1]
|
self.assertIsNotNone(multihead_outputs)
|
||||||
|
self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
def check_attentions_validity(attentions):
|
||||||
# Remove Nan
|
# Remove Nan
|
||||||
for t in attentions:
|
for t in attentions:
|
||||||
self.assertLess(
|
self.assertLess(
|
||||||
@@ -485,14 +495,19 @@ class ModelTesterMixin:
|
|||||||
t.masked_fill(torch.isnan(t), 0.0) for t in attentions
|
t.masked_fill(torch.isnan(t), 0.0) for t in attentions
|
||||||
] # remove them (the test is less complete)
|
] # remove them (the test is less complete)
|
||||||
|
|
||||||
self.assertIsNotNone(multihead_outputs)
|
|
||||||
self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
|
|
||||||
self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
|
self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||||
self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
|
self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||||
|
if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module
|
||||||
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
|
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||||
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
||||||
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
check_attentions_validity(outputs.encoder_attentions)
|
||||||
|
check_attentions_validity(outputs.decoder_attentions)
|
||||||
|
else:
|
||||||
|
check_attentions_validity(outputs.attentions)
|
||||||
|
|
||||||
def test_head_pruning(self):
|
def test_head_pruning(self):
|
||||||
if not self.test_pruning:
|
if not self.test_pruning:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -54,16 +54,24 @@ def prepare_marian_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -146,9 +154,10 @@ class MarianModelTester:
|
|||||||
model = MarianModel(config=config).get_decoder().to(torch_device).eval()
|
model = MarianModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -214,7 +223,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -49,16 +49,24 @@ def prepare_mbart_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -138,9 +146,10 @@ class MBartModelTester:
|
|||||||
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
|
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -210,7 +219,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -41,16 +41,24 @@ def prepare_pegasus_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": attention_mask,
|
"decoder_attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -130,9 +138,10 @@ class PegasusModelTester:
|
|||||||
model = PegasusModel(config=config).get_decoder().to(torch_device).eval()
|
model = PegasusModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -198,7 +207,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = True
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user