From a0dcefa382a541d0fecd634d6d0c3f97cd221faf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 7 Oct 2019 17:53:58 +0200 Subject: [PATCH] generalize BertSelfAttention to take separate query, key, value There is currently no way to specify the query, key and value separately in the Attention module. However, the decoder's "encoder-decoder attention" layers take the decoder's last output as a query, the encoder's states as key and value. We thus modify the existing code so query, key and value can be specified separately. This obviously raises some naming issues; `BertSelfAttention` is not a self-attention module anymore. The way the residual is forwarded is now awkward, etc. We will need to do some refactoring once the decoder is fully implemented. --- transformers/modeling_bert.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index f2e2dba589..8a2624f8f0 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -198,10 +198,10 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, head_mask=None): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + def forward(self, query, key, value, attention_mask=None, head_mask=None): + mixed_query_layer = self.query(query) + mixed_key_layer = self.key(key) + mixed_value_layer = self.value(value) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) @@ -279,9 +279,12 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, input_tensor, attention_mask=None, head_mask=None): - self_outputs = 
self.self(input_tensor, attention_mask, head_mask) - attention_output = self.output(self_outputs[0], input_tensor) + def forward(self, query_tensor, key_tensor, value_tensor, attention_mask=None, head_mask=None): + self_outputs = self.self(query_tensor, key_tensor, value_tensor, attention_mask, head_mask) + # in encoder-decoder attention we use the output of the previous decoder stage as the query + # in the Multi-Head Attention. We thus pass query_tensor as the residual in BertOutput. + # This shows the limits of the current code architecture, which may benefit from some refactoring. + attention_output = self.output(self_outputs[0], query_tensor) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -323,7 +326,11 @@ class BertEncoderLayer(nn.Module): self.output = BertOutput(config) def forward(self, hidden_states, attention_mask=None, head_mask=None): - attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_outputs = self.attention(query_tensor=hidden_states, + key_tensor=hidden_states, + value_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask) attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) @@ -333,6 +340,7 @@ class BertEncoderLayer(nn.Module): class BertDecoderLayer(nn.Module): def __init__(self, config): + super(BertDecoderLayer, self).__init__() raise NotImplementedError def forward(self, hidden_state, encoder_output):