From 075206961700fd2359f5c5cfc86a8c18d8404406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 16:12:22 +0200 Subject: [PATCH] adapt attention masks for the decoder case The introduction of a decoder introduces 2 changes: - We need to be able to specify a separate mask in the cross attention to mask the positions corresponding to padding tokens in the encoder state. - The self-attention in the decoder needs to be causal on top of not attending to padding tokens. --- transformers/modeling_bert.py | 66 +++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index fbf3c84646..cd9151cf62 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -198,12 +198,16 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): mixed_query_layer = self.query(hidden_states) + # if the attention Module is an encoder-decoder cross-attention module + # the keys & values are given by the encoder; the attention mask + # needs to be such that there is no attention on the encoder's padding tokens. 
if encoder_hidden_states is not None: mixed_key_layer = self.key(encoder_hidden_states) mixed_value_layer = self.value(encoder_hidden_states) + attention_mask = encoder_attention_mask else: mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) @@ -284,8 +288,8 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): - self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states) + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): + self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -330,13 +334,13 @@ class BertLayer(nn.Module): self.intermediate = BertIntermediate(config) self.output = BertOutput(config) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None): + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights if self.is_decoder and encoder_hidden_state is not None: - cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state) + cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask) attention_output = 
cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights @@ -346,6 +350,7 @@ class BertLayer(nn.Module): return outputs +# NOTE I think we may need to call encoder_hidden_states[i] for each layer class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() @@ -353,14 +358,14 @@ class BertEncoder(nn.Module): self.output_hidden_states = config.output_hidden_states self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): all_hidden_states = () all_attentions = () for i, layer_module in enumerate(self.layer): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask) hidden_states = layer_outputs[0] if self.output_attentions: @@ -579,6 +584,7 @@ class BertModel(BertPreTrainedModel): """ def __init__(self, config): super(BertModel, self).__init__(config) + self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) @@ -601,18 +607,47 @@ class BertModel(BertPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, encoder_hidden_state=None): + head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): + """ Forward pass on the Model. 
+ + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + every self-attention layer, following the architecture described in [1]. + + To behave as a decoder the model needs to be initialized with the + `is_decoder` argument of the config set to `True`. An + `encoder_hidden_state` is expected as an input to the forward pass. + When used as a decoder, there are two kinds of attention masks to specify: + + (1) Self-attention masks that need to be causal (only attends to + previous tokens); + (2) A cross-attention mask that prevents the module + from attending to the encoder's padding tokens. + + [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in + neural information processing systems. 2017. + """ if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just make it broadcastable to all heads. 
+ if attention_mask.dims() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + + # provided a padding mask of dimensions [batch_size, seq_length] + # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length] + # - if decoder, make it causal + if attention_mask.dims() == 2: + if self.config.is_decoder: + batch_size, seq_length = input_ids.size() + seq_ids = torch.arange(seq_length) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -641,7 +676,8 @@ class BertModel(BertPreTrainedModel): encoder_outputs = self.encoder(embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_state=encoder_hidden_state) + encoder_hidden_state=encoder_hidden_state, + encoder_attention_mask=encoder_attention_mask) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output)