From 357fb1c5d8b6a16f042f9b504f023d935086e8e5 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 18 Jan 2021 13:35:22 +0100 Subject: [PATCH] Add head_mask/decoder_head_mask for BART (#9569) * Add head_mask/decoder_head_mask for BART This branch implements head_mask and decoder_head_mask for BART-based models. Full list below: - BART - MBart - Blenderbot - BlenderbotSmall - Marian - Pegasus Everything is accompanied by updated tests. * Fix test_headmasking for BART models * Fix test_headmasking for BART-like models, which have only 2 layers in each module. The condition ``` self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0) ``` is, therefore, invalid for encoder-decoder models considering the `head_mask` ``` head_mask = torch.ones( self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device, ) head_mask[0, 0] = 0 head_mask[-1, :-1] = 0 ``` specified in the `test_headmasking` test/function. 
* Adjust test_modeling_common.py to reflect T5 input args * Update tests/test_modeling_common.py Co-authored-by: Lysandre Debut * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make style * make fix-copies Co-authored-by: Patrick von Platen Co-authored-by: Lysandre Debut Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/bart/modeling_bart.py | 118 +++++++++++++++++- .../models/blenderbot/modeling_blenderbot.py | 110 +++++++++++++++- .../modeling_blenderbot_small.py | 108 +++++++++++++++- .../models/marian/modeling_marian.py | 110 +++++++++++++++- .../models/mbart/modeling_mbart.py | 108 +++++++++++++++- .../models/pegasus/modeling_pegasus.py | 110 +++++++++++++++- tests/test_modeling_bart.py | 13 +- tests/test_modeling_blenderbot.py | 13 +- tests/test_modeling_blenderbot_small.py | 13 +- tests/test_modeling_common.py | 53 +++++--- tests/test_modeling_marian.py | 13 +- tests/test_modeling_mbart.py | 13 +- tests/test_modeling_pegasus.py | 13 +- 13 files changed, 735 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 392aeca466..b34d81741d 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -160,6 +160,7 @@ class BartAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -227,6 +228,13 @@ class BartAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a 
single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -275,19 +283,30 @@ class BartEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
""" residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -346,6 +365,8 @@ class BartDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -358,6 +379,10 @@ class BartDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. 
See ``attentions`` under @@ -373,6 +398,7 @@ class BartDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -391,6 +417,7 @@ class BartDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -568,6 +595,18 @@ BART_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -645,6 +684,7 @@ class BartEncoder(BartPretrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -668,6 +708,12 @@ class BartEncoder(BartPretrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -714,7 +760,13 @@ class BartEncoder(BartPretrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -734,9 +786,15 @@ class BartEncoder(BartPretrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -791,6 +849,8 @@ class BartDecoder(BartPretrainedModel): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -827,6 +887,19 @@ class BartDecoder(BartPretrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -904,6 +977,12 @@ class BartDecoder(BartPretrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -933,6 +1012,8 @@ class BartDecoder(BartPretrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -942,6 +1023,8 @@ class BartDecoder(BartPretrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1018,6 +1101,8 @@ class BartModel(BartPretrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1046,6 +1131,7 @@ class BartModel(BartPretrainedModel): 
encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1065,6 +1151,8 @@ class BartModel(BartPretrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1143,6 +1231,8 @@ class BartForConditionalGeneration(BartPretrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1175,6 +1265,8 @@ class BartForConditionalGeneration(BartPretrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1207,7 +1299,14 @@ class BartForConditionalGeneration(BartPretrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -1219,6 +1318,7 @@ class BartForConditionalGeneration(BartPretrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } @@ -1278,6 +1378,8 @@ class BartForSequenceClassification(BartPretrainedModel): attention_mask=None, 
decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -1306,6 +1408,8 @@ class BartForSequenceClassification(BartPretrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1379,6 +1483,8 @@ class BartForQuestionAnswering(BartPretrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, start_positions=None, end_positions=None, @@ -1408,6 +1514,8 @@ class BartForQuestionAnswering(BartPretrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 0373c67f35..a47017278e 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -160,6 +160,7 @@ class BlenderbotAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -227,6 +228,13 @@ class BlenderbotAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + 
assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -276,12 +284,20 @@ class BlenderbotEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
@@ -289,7 +305,10 @@ class BlenderbotEncoderLayer(nn.Module): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -348,6 +367,8 @@ class BlenderbotDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -360,6 +381,10 @@ class BlenderbotDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. 
See ``attentions`` under @@ -376,6 +401,7 @@ class BlenderbotDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -394,6 +420,7 @@ class BlenderbotDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -528,6 +555,18 @@ BLENDERBOT_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_blenderbot._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -605,6 +644,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -628,6 +668,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -673,7 +719,13 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -693,9 +745,15 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -751,8 +809,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -778,6 +838,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -789,6 +855,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -866,6 +939,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -895,6 +974,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -902,8 +983,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): layer_outputs = decoder_layer( hidden_states, attention_mask=combined_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -989,6 +1072,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1025,6 +1110,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1044,6 +1130,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1135,6 +1223,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + 
head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1167,6 +1257,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1199,7 +1291,14 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -1211,6 +1310,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 1a0aac32d5..dd8e1020cc 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -158,6 +158,7 @@ class BlenderbotSmallAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -225,6 +226,13 
@@ class BlenderbotSmallAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -274,19 +282,30 @@ class BlenderbotSmallEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
""" residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -346,6 +365,8 @@ class BlenderbotSmallDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -358,6 +379,10 @@ class BlenderbotSmallDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. 
See ``attentions`` under @@ -373,6 +398,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -391,6 +417,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -529,6 +556,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_blenderbot_small._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -606,6 +645,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -629,6 +669,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -675,7 +721,13 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -695,9 +747,15 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -753,6 +811,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -789,6 +849,19 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -868,6 +941,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -897,6 +974,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -906,6 +985,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -977,6 +1058,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1013,6 +1096,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): encoder_outputs = self.encoder( 
input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1032,6 +1116,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1111,6 +1197,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1143,6 +1231,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1175,7 +1265,14 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -1187,6 +1284,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git 
a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 33caacfebd..6e8bdf0d0c 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -174,6 +174,7 @@ class MarianAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -241,6 +242,13 @@ class MarianAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -290,19 +298,30 @@ class MarianEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. """ residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -362,6 +381,8 @@ class MarianDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -374,6 +395,10 @@ class MarianDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. 
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -389,6 +414,7 @@ class MarianDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -407,6 +433,7 @@ class MarianDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -539,6 +566,18 @@ MARIAN_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_marian._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -614,6 +653,7 @@ class MarianEncoder(MarianPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -637,6 +677,12 @@ class MarianEncoder(MarianPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -682,7 +728,13 @@ class MarianEncoder(MarianPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -702,9 +754,15 @@ class MarianEncoder(MarianPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -757,6 +815,8 @@ class MarianDecoder(MarianPreTrainedModel): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -793,6 +853,19 @@ class MarianDecoder(MarianPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -869,6 +942,12 @@ class MarianDecoder(MarianPreTrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -898,6 +977,8 @@ class MarianDecoder(MarianPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -907,6 +988,8 @@ class MarianDecoder(MarianPreTrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -978,6 +1061,8 @@ class MarianModel(MarianPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1015,6 +1100,7 @@ class 
MarianModel(MarianPreTrainedModel): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1034,6 +1120,8 @@ class MarianModel(MarianPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1118,6 +1206,8 @@ class MarianMTModel(MarianPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1151,6 +1241,8 @@ class MarianMTModel(MarianPreTrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1183,7 +1275,14 @@ class MarianMTModel(MarianPreTrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -1195,6 +1294,7 @@ class MarianMTModel(MarianPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 
53e4833339..ac631057d8 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -167,6 +167,7 @@ class MBartAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -234,6 +235,13 @@ class MBartAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -282,12 +290,20 @@ class MBartEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -295,7 +311,10 @@ class MBartEncoderLayer(nn.Module): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -353,6 +372,8 @@ class MBartDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -365,6 +386,10 @@ class MBartDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. 
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -381,6 +406,7 @@ class MBartDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -399,6 +425,7 @@ class MBartDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -573,6 +600,18 @@ MBART_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_mbart._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**.
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -651,6 +690,7 @@ class MBartEncoder(MBartPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -674,6 +714,12 @@ class MBartEncoder(MBartPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -720,7 +766,13 @@ class MBartEncoder(MBartPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -740,9 +792,15 @@ class MBartEncoder(MBartPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -800,6 +858,8 @@ class MBartDecoder(MBartPreTrainedModel): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -836,6 +896,19 @@ class MBartDecoder(MBartPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -913,6 +986,12 @@ class MBartDecoder(MBartPreTrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -942,6 +1021,8 @@ class MBartDecoder(MBartPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -951,6 +1032,8 @@ class MBartDecoder(MBartPreTrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1029,6 +1112,8 @@ class MBartModel(MBartPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1054,6 +1139,7 @@ class MBartModel(MBartPreTrainedModel): 
encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1073,6 +1159,8 @@ class MBartModel(MBartPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1151,6 +1239,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1182,6 +1272,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1284,6 +1376,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -1312,6 +1406,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1386,6 +1482,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + 
decoder_head_mask=None, encoder_outputs=None, start_positions=None, end_positions=None, @@ -1415,6 +1513,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index a29fdfd49f..13deb70da9 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -175,6 +175,7 @@ class PegasusAttention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -242,6 +243,13 @@ class PegasusAttention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. 
@@ -291,12 +299,20 @@ class PegasusEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
@@ -304,7 +320,10 @@ class PegasusEncoderLayer(nn.Module): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -363,6 +382,8 @@ class PegasusDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -375,6 +396,10 @@ class PegasusDecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.decoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. 
See ``attentions`` under @@ -391,6 +416,7 @@ class PegasusDecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -409,6 +435,7 @@ class PegasusDecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -540,6 +567,18 @@ PEGASUS_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`modeling_pegasus._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -617,6 +656,7 @@ class PegasusEncoder(PegasusPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -640,6 +680,12 @@ class PegasusEncoder(PegasusPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -686,7 +732,13 @@ class PegasusEncoder(PegasusPreTrainedModel): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+ for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -706,9 +758,15 @@ class PegasusEncoder(PegasusPreTrainedModel): create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -765,6 +823,8 @@ class PegasusDecoder(PegasusPreTrainedModel): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -801,6 +861,19 @@ class PegasusDecoder(PegasusPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -877,6 +950,12 @@ class PegasusDecoder(PegasusPreTrainedModel): all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -906,6 +985,8 @@ class PegasusDecoder(PegasusPreTrainedModel): combined_attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: @@ -915,6 +996,8 @@ class PegasusDecoder(PegasusPreTrainedModel): attention_mask=combined_attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -988,6 +1071,8 @@ class PegasusModel(PegasusPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1025,6 +1110,7 @@ class 
PegasusModel(PegasusPreTrainedModel): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1044,6 +1130,8 @@ class PegasusModel(PegasusPreTrainedModel): attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1123,6 +1211,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1156,6 +1246,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1188,7 +1280,14 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -1200,6 +1299,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/tests/test_modeling_bart.py 
b/tests/test_modeling_bart.py index 1100c893ae..7bf7bde03b 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -53,16 +53,24 @@ def prepare_bart_inputs_dict( decoder_input_ids=None, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -142,9 +150,10 @@ class BartModelTester: model = BartModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -393,7 +402,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self): diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index b72cacf711..826d43afff 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -40,16 +40,24 @@ def 
prepare_blenderbot_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -129,9 +137,10 @@ class BlenderbotModelTester: model = BlenderbotModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -197,7 +206,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self): diff --git a/tests/test_modeling_blenderbot_small.py b/tests/test_modeling_blenderbot_small.py index c3538c7be1..abb223413a 100644 --- a/tests/test_modeling_blenderbot_small.py +++ b/tests/test_modeling_blenderbot_small.py @@ -48,16 +48,24 @@ def prepare_blenderbot_small_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + 
decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -137,9 +145,10 @@ class BlenderbotSmallModelTester: model = BlenderbotSmallModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -205,7 +214,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e33efd34b4..d14cab8e69 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -204,9 +204,13 @@ class ModelTesterMixin: "attention_mask", "decoder_input_ids", "decoder_attention_mask", - "encoder_outputs", ] - self.assertListEqual(arg_names[:5], expected_arg_names) + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "encoder_outputs"] + if "head_mask" and 
"decoder_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) else: expected_arg_names = ["input_ids"] self.assertListEqual(arg_names[:1], expected_arg_names) @@ -395,7 +399,6 @@ class ModelTesterMixin: attention_mask = inputs["attention_mask"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] - traced_model = torch.jit.trace( model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) ) @@ -465,6 +468,11 @@ class ModelTesterMixin: head_mask.requires_grad_(requires_grad=True) inputs = self._prepare_for_class(inputs_dict, model_class).copy() inputs["head_mask"] = head_mask + if model.config.is_encoder_decoder: + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model + inputs["decoder_head_mask"] = head_mask outputs = model(**inputs, return_dict=True) @@ -474,24 +482,31 @@ class ModelTesterMixin: output.backward() multihead_outputs = head_mask.grad - attentions = outputs[-1] - - # Remove Nan - for t in attentions: - self.assertLess( - torch.sum(torch.isnan(t)), t.numel() / 4 - ) # Check we don't have more than 25% nans (arbitrary) - attentions = [ - t.masked_fill(torch.isnan(t), 0.0) for t in attentions - ] # remove them (the test is less complete) - self.assertIsNotNone(multihead_outputs) self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers) - self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0) - self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0) - self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0) - self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0) - self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0) + + def 
check_attentions_validity(attentions): + # Remove Nan + for t in attentions: + self.assertLess( + torch.sum(torch.isnan(t)), t.numel() / 4 + ) # Check we don't have more than 25% nans (arbitrary) + attentions = [ + t.masked_fill(torch.isnan(t), 0.0) for t in attentions + ] # remove them (the test is less complete) + + self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0) + self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0) + if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module + self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0) + self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0) + self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0) + + if model.config.is_encoder_decoder: + check_attentions_validity(outputs.encoder_attentions) + check_attentions_validity(outputs.decoder_attentions) + else: + check_attentions_validity(outputs.attentions) def test_head_pruning(self): if not self.test_pruning: diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index e5a815d963..02313b20dc 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -54,16 +54,24 @@ def prepare_marian_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": 
head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -146,9 +154,10 @@ class MarianModelTester: model = MarianModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -214,7 +223,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase all_generative_model_classes = (MarianMTModel,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self): diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index d70362edd8..7629e2beeb 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -49,16 +49,24 @@ def prepare_mbart_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -138,9 +146,10 @@ class MBartModelTester: model = MBartModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask 
= inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -210,7 +219,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self): diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 512b50ff75..823cb00453 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -41,16 +41,24 @@ def prepare_pegasus_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -130,9 +138,10 @@ class PegasusModelTester: model = PegasusModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = 
model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -198,7 +207,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False + test_head_masking = True test_missing_keys = False def setUp(self):