From 0c6c0afc0e06bc2e8277e16bc5e57e470232f63b Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 1 Feb 2021 07:30:21 +0100 Subject: [PATCH] Add head_mask and decoder_head_mask to FSMT (#9819) * Add {decoder_,}head_mask to fsmt_modeling.py * Enable test_headmasking and some changes to docs * Remove test_head_masking flag from fsmt test file Remove test_head_masking flag from test_modeling_fsmt.py since test_head_masking is set to be True by default (thus it is redundant to store). * Merge master and remove test_head_masking = True * Rebase necessary due to an update of jaxlib * Remove test_head_masking=True in tests/test_modeling_fsmt.py as it is redundant. --- src/transformers/models/fsmt/modeling_fsmt.py | 117 +++++++++++++++--- tests/test_modeling_fsmt.py | 9 +- 2 files changed, 109 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index be3b102255..957ba4e84e 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -240,6 +240,17 @@ FSMT_INPUTS_DOCSTRING = r""" also be used by default. If you want to change padding behavior, you should read :func:`modeling_fstm._prepare_fstm_decoder_inputs` and modify. See diagram 1 in the paper for more info on the default strategy + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a @@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0): def _prepare_fsmt_decoder_inputs( - config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 + config, + input_ids, + decoder_input_ids=None, + decoder_padding_mask=None, + causal_mask_dtype=torch.float32, ): """ Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided. @@ -377,21 +392,27 @@ class EncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) - def forward(self, x, encoder_padding_mask, output_attentions=False): + def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False): """ Args: - x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_padding_mask (ByteTensor): binary ByteTensor of shape + x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. for t_tgt, t_src is excluded (or masked out), =0 means it is included in attention + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x x, attn_weights = self.self_attn( - query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions + query=x, + key=x, + key_padding_mask=encoder_padding_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -432,21 +453,32 @@ class FSMTEncoder(nn.Module): ) # type: List[EncoderLayer] def forward( - self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True + self, + input_ids, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, ): """ Args: - input_ids (LongTensor): tokens in the source language of shape + input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape `(batch, src_len)` - attention_mask (torch.LongTensor): indicating which indices are padding tokens + attention_mask (:obj:`torch.LongTensor`): indicating which indices are padding tokens + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. Returns: BaseModelOutput or Tuple comprised of: - - **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - - **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len, - batch, embed_dim)`. Only populated if *output_hidden_states:* is True. - - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer. + - **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)` + - **encoder_states** (:obj:`Tuple(torch.FloatTensor`)): all intermediate hidden states of shape + `(src_len, batch, embed_dim)`. Only populated if *output_hidden_states:* is True. + - **all_attentions** (:obj:`Tuple(torch.FloatTensor`)): Attention weights for each layer. During training might not be of length n_layers because of layer dropout. """ # check attention mask and invert @@ -463,7 +495,12 @@ class FSMTEncoder(nn.Module): encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: x = x.transpose(0, 1) # T x B x C -> B x T x C encoder_states += (x,) @@ -473,7 +510,12 @@ class FSMTEncoder(nn.Module): if self.training and (dropout_probability < self.layerdrop): # skip the layer attn = None else: - x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions) + x, attn = encoder_layer( + x, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) if output_attentions: all_attentions = all_attentions + (attn,) @@ -522,6 +564,8 @@ class DecoderLayer(nn.Module): encoder_attn_mask=None, layer_state=None, causal_mask=None, + layer_head_mask=None, + encoder_layer_head_mask=None, decoder_padding_mask=None, output_attentions=False, ): @@ -537,6 +581,7 @@ class DecoderLayer(nn.Module): layer_state=layer_state, # adds keys to layer state key_padding_mask=decoder_padding_mask, attn_mask=causal_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) @@ -551,6 +596,7 @@ class DecoderLayer(nn.Module): key=encoder_hidden_states, key_padding_mask=encoder_attn_mask, layer_state=layer_state, # mutates layer state + layer_head_mask=encoder_layer_head_mask, output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) @@ -611,6 +657,8 @@ class FSMTDecoder(nn.Module): encoder_padding_mask, decoder_padding_mask, decoder_causal_mask, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=False, output_attentions=False, @@ -622,12 +670,24 @@ class FSMTDecoder(nn.Module): EMNLP 2019). Args: - input_ids (LongTensor): previous decoder outputs of shape - `(batch, tgt_len)`, for teacher forcing + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`): + previous decoder outputs for teacher forcing encoder_hidden_states: output from the encoder, used for encoder-side attention encoder_padding_mask: for ignoring pad tokens past_key_values (dict or None): dictionary used for storing state during generation + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. Returns: BaseModelOutputWithPast or tuple: @@ -662,6 +722,12 @@ class FSMTDecoder(nn.Module): all_self_attns = () if output_attentions else None all_cross_attns = () if output_attentions else None next_decoder_cache = [] + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -681,6 +747,8 @@ class FSMTDecoder(nn.Module): decoder_padding_mask=decoder_padding_mask, layer_state=layer_state, causal_mask=decoder_causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), output_attentions=output_attentions, ) @@ -761,6 +829,7 @@ class Attention(nn.Module): key_padding_mask: Optional[Tensor] = None, layer_state: Optional[Dict[str, Optional[Tensor]]] = None, attn_mask: Optional[Tensor] = None, + layer_head_mask: Optional[Tensor] = None, output_attentions=False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" @@ -830,6 +899,13 @@ class Attention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # make sure that attn_weights are included in graph attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) @@ -923,6 +999,8 @@ class FSMTModel(PretrainedFSMTModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Tuple] = None, past_key_values=None, use_cache=None, @@ -958,6 +1036,7 @@ class FSMTModel(PretrainedFSMTModel): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -977,6 +1056,8 @@ class FSMTModel(PretrainedFSMTModel): attention_mask, decoder_padding_mask, decoder_causal_mask=causal_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, @@ -1052,6 +1133,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, labels=None, @@ -1080,6 +1163,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, diff --git a/tests/test_modeling_fsmt.py b/tests/test_modeling_fsmt.py index 4f0e9c9ecb..860e888023 100644 --- a/tests/test_modeling_fsmt.py +++ b/tests/test_modeling_fsmt.py @@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict( config, input_ids, attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False test_missing_keys = False def setUp(self):