From 638fe7f5a4f5c3bec5b39cee374f13b4675cdb18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 10:13:07 +0200 Subject: [PATCH] correct composition of padding and causal masks --- transformers/modeling_bert.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 2553bc0efb..05ab3395de 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -288,8 +288,8 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): - self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask) + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): + self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -350,7 +350,6 @@ class BertLayer(nn.Module): return outputs -# NOTE I think we may need to call encoder_hidden_states[i] for each layer class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() @@ -365,7 +364,8 @@ class BertEncoder(nn.Module): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask) + encoder_hidden_state = encoder_hidden_states[i] + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask) hidden_states = layer_outputs[0] if self.output_attentions: @@ -607,22 +607,26 @@ class BertModel(BertPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): + head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): """ Forward pass on the Model. + The values of the attention matrix (shape [batch_size, seq_length]) + should be 1.0 for the position we want to attend to and 0. for the ones + we do not want to attend to. + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between ever self-attention layer, following the architecture described in [1]. To behave like as a decoder the model needs to be initialized with the `is_decoder` argument of the config set to `True`. An - `encoder_hidden_state` is expected as an input to the forward pass. + `encoder_hidden_states` is expected as an input to the forward pass. When a decoder, there are two kinds of attention masks to specify: (1) Self-attention masks that need to be causal (only attends to previous tokens); (2) A cross-attention mask that prevents the module - from attending to the encoder' padding tokens. + from attending to the encoder's padding tokens. [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017. @@ -632,20 +636,20 @@ class BertModel(BertPreTrainedModel): if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) - # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just make it broadcastable to all heads. + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] - # provided a padding mask of dimensions [batch_size, seq_length] - # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length] - # - if decoder, make it causal + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if attention_mask.dim() == 2: if self.config.is_decoder: batch_size, seq_length = input_ids.size() seq_ids = torch.arange(seq_length) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :] + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: extended_attention_mask = attention_mask[:, None, None, :] @@ -676,7 +680,7 @@ class BertModel(BertPreTrainedModel): encoder_outputs = self.encoder(embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_state=encoder_hidden_state, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output)