Add head_mask and decoder_head_mask to FSMT (#9819)
* Add {decoder_,}head_mask to fsmt_modeling.py
* Enable test_headmasking and some changes to docs
* Remove test_head_masking flag from fsmt test file
Remove test_head_masking flag from test_modeling_fsmt.py
since test_head_masking is set to be True by default (thus it is redundant to store).
* Merge master and remove test_head_masking = True
* Rebase necessary due to an update of jaxlib
* Remove test_head_masking=True in tests/test_modeling_fsmt.py
as it is redundant.
This commit is contained in:
@@ -240,6 +240,17 @@ FSMT_INPUTS_DOCSTRING = r"""
|
||||
also be used by default. If you want to change padding behavior, you should read
|
||||
:func:`modeling_fstm._prepare_fstm_decoder_inputs` and modify. See diagram 1 in the paper for more info on
|
||||
the default strategy
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||
@@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0):
|
||||
|
||||
|
||||
def _prepare_fsmt_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids=None,
|
||||
decoder_padding_mask=None,
|
||||
causal_mask_dtype=torch.float32,
|
||||
):
|
||||
"""
|
||||
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
|
||||
@@ -377,21 +392,27 @@ class EncoderLayer(nn.Module):
|
||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(self, x, encoder_padding_mask, output_attentions=False):
|
||||
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
||||
x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape
|
||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||
for t_tgt, t_src is excluded (or masked out), =0 means it is
|
||||
included in attention
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
|
||||
query=x,
|
||||
key=x,
|
||||
key_padding_mask=encoder_padding_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
@@ -432,21 +453,32 @@ class FSMTEncoder(nn.Module):
|
||||
) # type: List[EncoderLayer]
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_ids (LongTensor): tokens in the source language of shape
|
||||
input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape
|
||||
`(batch, src_len)`
|
||||
attention_mask (torch.LongTensor): indicating which indices are padding tokens
|
||||
attention_mask (:obj:`torch.LongTensor`): indicating which indices are padding tokens
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
Returns:
|
||||
BaseModelOutput or Tuple comprised of:
|
||||
|
||||
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
|
||||
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len,
|
||||
batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
|
||||
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
|
||||
- **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
|
||||
- **encoder_states** (:obj:`Tuple(torch.FloatTensor`)): all intermediate hidden states of shape
|
||||
`(src_len, batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
|
||||
- **all_attentions** (:obj:`Tuple(torch.FloatTensor`)): Attention weights for each layer.
|
||||
During training might not be of length n_layers because of layer dropout.
|
||||
"""
|
||||
# check attention mask and invert
|
||||
@@ -463,7 +495,12 @@ class FSMTEncoder(nn.Module):
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
for encoder_layer in self.layers:
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
x = x.transpose(0, 1) # T x B x C -> B x T x C
|
||||
encoder_states += (x,)
|
||||
@@ -473,7 +510,12 @@ class FSMTEncoder(nn.Module):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
attn = None
|
||||
else:
|
||||
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
|
||||
x, attn = encoder_layer(
|
||||
x,
|
||||
attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (attn,)
|
||||
@@ -522,6 +564,8 @@ class DecoderLayer(nn.Module):
|
||||
encoder_attn_mask=None,
|
||||
layer_state=None,
|
||||
causal_mask=None,
|
||||
layer_head_mask=None,
|
||||
encoder_layer_head_mask=None,
|
||||
decoder_padding_mask=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
@@ -537,6 +581,7 @@ class DecoderLayer(nn.Module):
|
||||
layer_state=layer_state, # adds keys to layer state
|
||||
key_padding_mask=decoder_padding_mask,
|
||||
attn_mask=causal_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
@@ -551,6 +596,7 @@ class DecoderLayer(nn.Module):
|
||||
key=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
@@ -611,6 +657,8 @@ class FSMTDecoder(nn.Module):
|
||||
encoder_padding_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
@@ -622,12 +670,24 @@ class FSMTDecoder(nn.Module):
|
||||
EMNLP 2019).
|
||||
|
||||
Args:
|
||||
input_ids (LongTensor): previous decoder outputs of shape
|
||||
`(batch, tgt_len)`, for teacher forcing
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`):
|
||||
previous decoder outputs for teacher forcing
|
||||
encoder_hidden_states: output from the encoder, used for
|
||||
encoder-side attention
|
||||
encoder_padding_mask: for ignoring pad tokens
|
||||
past_key_values (dict or None): dictionary used for storing state during generation
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
Returns:
|
||||
BaseModelOutputWithPast or tuple:
|
||||
@@ -662,6 +722,12 @@ class FSMTDecoder(nn.Module):
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attns = () if output_attentions else None
|
||||
next_decoder_cache = []
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@@ -681,6 +747,8 @@ class FSMTDecoder(nn.Module):
|
||||
decoder_padding_mask=decoder_padding_mask,
|
||||
layer_state=layer_state,
|
||||
causal_mask=decoder_causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@@ -761,6 +829,7 @@ class Attention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
layer_head_mask: Optional[Tensor] = None,
|
||||
output_attentions=False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||
@@ -830,6 +899,13 @@ class Attention(nn.Module):
|
||||
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
assert layer_head_mask.size() == (
|
||||
self.num_heads,
|
||||
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# make sure that attn_weights are included in graph
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
@@ -923,6 +999,8 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Tuple] = None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
@@ -958,6 +1036,7 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -977,6 +1056,8 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
attention_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@@ -1052,6 +1133,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
@@ -1080,6 +1163,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
|
||||
Reference in New Issue
Block a user