Add head_mask and decoder_head_mask to FSMT (#9819)
* Add {decoder_,}head_mask to fsmt_modeling.py
* Enable test_headmasking and some changes to docs
* Remove test_head_masking flag from fsmt test file
Remove test_head_masking flag from test_modeling_fsmt.py
since test_head_masking is set to be True by default (thus it is redundant to store).
* Merge master and remove test_head_masking = True
* Rebase necessary due to an update of jaxlib
* Remove test_head_masking=True in tests/test_modeling_fsmt.py
as it is redundant.
This commit is contained in:
@@ -240,6 +240,17 @@ FSMT_INPUTS_DOCSTRING = r"""
|
|||||||
also be used by default. If you want to change padding behavior, you should read
|
also be used by default. If you want to change padding behavior, you should read
|
||||||
:func:`modeling_fstm._prepare_fstm_decoder_inputs` and modify. See diagram 1 in the paper for more info on
|
:func:`modeling_fstm._prepare_fstm_decoder_inputs` and modify. See diagram 1 in the paper for more info on
|
||||||
the default strategy
|
the default strategy
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||||
@@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0):
|
|||||||
|
|
||||||
|
|
||||||
def _prepare_fsmt_decoder_inputs(
|
def _prepare_fsmt_decoder_inputs(
|
||||||
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids=None,
|
||||||
|
decoder_padding_mask=None,
|
||||||
|
causal_mask_dtype=torch.float32,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
|
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
|
||||||
@@ -377,21 +392,27 @@ class EncoderLayer(nn.Module):
|
|||||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
def forward(self, x, encoder_padding_mask, output_attentions=False):
|
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape
|
||||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||||
for t_tgt, t_src is excluded (or masked out), =0 means it is
|
for t_tgt, t_src is excluded (or masked out), =0 means it is
|
||||||
included in attention
|
included in attention
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||||
"""
|
"""
|
||||||
residual = x
|
residual = x
|
||||||
x, attn_weights = self.self_attn(
|
x, attn_weights = self.self_attn(
|
||||||
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
|
query=x,
|
||||||
|
key=x,
|
||||||
|
key_padding_mask=encoder_padding_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
x = residual + x
|
x = residual + x
|
||||||
@@ -432,21 +453,32 @@ class FSMTEncoder(nn.Module):
|
|||||||
) # type: List[EncoderLayer]
|
) # type: List[EncoderLayer]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (LongTensor): tokens in the source language of shape
|
input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape
|
||||||
`(batch, src_len)`
|
`(batch, src_len)`
|
||||||
attention_mask (torch.LongTensor): indicating which indices are padding tokens
|
attention_mask (:obj:`torch.LongTensor`): indicating which indices are padding tokens
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseModelOutput or Tuple comprised of:
|
BaseModelOutput or Tuple comprised of:
|
||||||
|
|
||||||
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
|
- **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
|
||||||
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len,
|
- **encoder_states** (:obj:`Tuple(torch.FloatTensor`)): all intermediate hidden states of shape
|
||||||
batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
|
`(src_len, batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
|
||||||
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
|
- **all_attentions** (:obj:`Tuple(torch.FloatTensor`)): Attention weights for each layer.
|
||||||
During training might not be of length n_layers because of layer dropout.
|
During training might not be of length n_layers because of layer dropout.
|
||||||
"""
|
"""
|
||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
@@ -463,7 +495,12 @@ class FSMTEncoder(nn.Module):
|
|||||||
|
|
||||||
encoder_states = () if output_hidden_states else None
|
encoder_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for encoder_layer in self.layers:
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
x = x.transpose(0, 1) # T x B x C -> B x T x C
|
x = x.transpose(0, 1) # T x B x C -> B x T x C
|
||||||
encoder_states += (x,)
|
encoder_states += (x,)
|
||||||
@@ -473,7 +510,12 @@ class FSMTEncoder(nn.Module):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
attn = None
|
attn = None
|
||||||
else:
|
else:
|
||||||
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
|
x, attn = encoder_layer(
|
||||||
|
x,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (attn,)
|
all_attentions = all_attentions + (attn,)
|
||||||
@@ -522,6 +564,8 @@ class DecoderLayer(nn.Module):
|
|||||||
encoder_attn_mask=None,
|
encoder_attn_mask=None,
|
||||||
layer_state=None,
|
layer_state=None,
|
||||||
causal_mask=None,
|
causal_mask=None,
|
||||||
|
layer_head_mask=None,
|
||||||
|
encoder_layer_head_mask=None,
|
||||||
decoder_padding_mask=None,
|
decoder_padding_mask=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
@@ -537,6 +581,7 @@ class DecoderLayer(nn.Module):
|
|||||||
layer_state=layer_state, # adds keys to layer state
|
layer_state=layer_state, # adds keys to layer state
|
||||||
key_padding_mask=decoder_padding_mask,
|
key_padding_mask=decoder_padding_mask,
|
||||||
attn_mask=causal_mask,
|
attn_mask=causal_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
@@ -551,6 +596,7 @@ class DecoderLayer(nn.Module):
|
|||||||
key=encoder_hidden_states,
|
key=encoder_hidden_states,
|
||||||
key_padding_mask=encoder_attn_mask,
|
key_padding_mask=encoder_attn_mask,
|
||||||
layer_state=layer_state, # mutates layer state
|
layer_state=layer_state, # mutates layer state
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||||
@@ -611,6 +657,8 @@ class FSMTDecoder(nn.Module):
|
|||||||
encoder_padding_mask,
|
encoder_padding_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask,
|
decoder_causal_mask,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -622,12 +670,24 @@ class FSMTDecoder(nn.Module):
|
|||||||
EMNLP 2019).
|
EMNLP 2019).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids (LongTensor): previous decoder outputs of shape
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`):
|
||||||
`(batch, tgt_len)`, for teacher forcing
|
previous decoder outputs for teacher forcing
|
||||||
encoder_hidden_states: output from the encoder, used for
|
encoder_hidden_states: output from the encoder, used for
|
||||||
encoder-side attention
|
encoder-side attention
|
||||||
encoder_padding_mask: for ignoring pad tokens
|
encoder_padding_mask: for ignoring pad tokens
|
||||||
past_key_values (dict or None): dictionary used for storing state during generation
|
past_key_values (dict or None): dictionary used for storing state during generation
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the heas is **masked**.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseModelOutputWithPast or tuple:
|
BaseModelOutputWithPast or tuple:
|
||||||
@@ -662,6 +722,12 @@ class FSMTDecoder(nn.Module):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attns = () if output_attentions else None
|
all_cross_attns = () if output_attentions else None
|
||||||
next_decoder_cache = []
|
next_decoder_cache = []
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -681,6 +747,8 @@ class FSMTDecoder(nn.Module):
|
|||||||
decoder_padding_mask=decoder_padding_mask,
|
decoder_padding_mask=decoder_padding_mask,
|
||||||
layer_state=layer_state,
|
layer_state=layer_state,
|
||||||
causal_mask=decoder_causal_mask,
|
causal_mask=decoder_causal_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -761,6 +829,7 @@ class Attention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
layer_head_mask: Optional[Tensor] = None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||||
@@ -830,6 +899,13 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# make sure that attn_weights are included in graph
|
# make sure that attn_weights are included in graph
|
||||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -923,6 +999,8 @@ class FSMTModel(PretrainedFSMTModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs: Optional[Tuple] = None,
|
encoder_outputs: Optional[Tuple] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -958,6 +1036,7 @@ class FSMTModel(PretrainedFSMTModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -977,6 +1056,8 @@ class FSMTModel(PretrainedFSMTModel):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask=causal_mask,
|
decoder_causal_mask=causal_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -1052,6 +1133,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
@@ -1080,6 +1163,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
|||||||
@@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict(
|
|||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user