adapt attention masks for the decoder case
The introduction of a decoder introduces 2 changes: - We need to be able to specify a separate mask in the cross attention to mask the positions corresponding to padding tokens in the encoder state. - The self-attention in the decoder needs to be causal on top of not attending to padding tokens.
This commit is contained in:
@@ -198,12 +198,16 @@ class BertSelfAttention(nn.Module):
|
|||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
return x.permute(0, 2, 1, 3)
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# if the attention Module is a encoder-decoder self attention module
|
# if the attention Module is a encoder-decoder self attention module
|
||||||
|
# they keys & values are given by the encoder; the attention mask
|
||||||
|
# needs to be such that there is no atention on the encoder's padding tokens.
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
mixed_key_layer = self.key(encoder_hidden_states)
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
mixed_value_layer = self.value(encoder_hidden_states)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
mixed_key_layer = self.key(hidden_states)
|
||||||
mixed_value_layer = self.value(hidden_states)
|
mixed_value_layer = self.value(hidden_states)
|
||||||
@@ -284,8 +288,8 @@ class BertAttention(nn.Module):
|
|||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
|
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
@@ -330,13 +334,13 @@ class BertLayer(nn.Module):
|
|||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
|
||||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
if self.is_decoder and encoder_hidden_state is not None:
|
if self.is_decoder and encoder_hidden_state is not None:
|
||||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
@@ -346,6 +350,7 @@ class BertLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE I think we may need to call encoder_hidden_states[i] for each layer
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertEncoder, self).__init__()
|
super(BertEncoder, self).__init__()
|
||||||
@@ -353,14 +358,14 @@ class BertEncoder(nn.Module):
|
|||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
|
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
@@ -579,6 +584,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertModel, self).__init__(config)
|
super(BertModel, self).__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.embeddings = BertEmbeddings(config)
|
self.embeddings = BertEmbeddings(config)
|
||||||
self.encoder = BertEncoder(config)
|
self.encoder = BertEncoder(config)
|
||||||
@@ -601,18 +607,47 @@ class BertModel(BertPreTrainedModel):
|
|||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||||
head_mask=None, encoder_hidden_state=None):
|
head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
|
||||||
|
""" Forward pass on the Model.
|
||||||
|
|
||||||
|
The model can behave as an encoder (with only self-attention) as well
|
||||||
|
as a decoder, in which case a layer of cross-attention is added between
|
||||||
|
ever self-attention layer, following the architecture described in [1].
|
||||||
|
|
||||||
|
To behave like as a decoder the model needs to be initialized with the
|
||||||
|
`is_decoder` argument of the config set to `True`. An
|
||||||
|
`encoder_hidden_state` is expected as an input to the forward pass.
|
||||||
|
When a decoder, there are two kinds of attention masks to specify:
|
||||||
|
|
||||||
|
(1) Self-attention masks that need to be causal (only attends to
|
||||||
|
previous tokens);
|
||||||
|
(2) A cross-attention mask that prevents the module
|
||||||
|
from attending to the encoder' padding tokens.
|
||||||
|
|
||||||
|
[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
|
||||||
|
neural information processing systems. 2017.
|
||||||
|
"""
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros_like(input_ids)
|
token_type_ids = torch.zeros_like(input_ids)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# ourselves in which case we just make it broadcastable to all heads.
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
if attention_mask.dims() == 3:
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
# provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
# - if decoder, make it causal
|
||||||
|
if attention_mask.dims() == 2:
|
||||||
|
if self.config.is_decoder:
|
||||||
|
batch_size, seq_length = input_ids.size()
|
||||||
|
seq_ids = torch.arange(seq_length)
|
||||||
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
|
||||||
|
else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -641,7 +676,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(embedding_output,
|
encoder_outputs = self.encoder(embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_state=encoder_hidden_state)
|
encoder_hidden_state=encoder_hidden_state,
|
||||||
|
encoder_attention_mask=encoder_attention_mask)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user