From cd6a59d5c1cb9c7905675fc82ce50df5e2bdf3f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 8 Oct 2019 10:11:02 +0200 Subject: [PATCH] add a decoder layer for Bert --- transformers/modeling_bert.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 8a2624f8f0..4011da18b4 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -341,10 +341,28 @@ class BertEncoderLayer(nn.Module): class BertDecoderLayer(nn.Module): def __init__(self, config): super(BertDecoderLayer, self).__init__() - raise NotImplementedError + self.self_attention = BertAttention(config) + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) - def forward(self, hidden_state, encoder_output): - raise NotImplementedError + def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None): + self_attention_outputs = self.self_attention(query_tensor=hidden_states, + key_tensor=hidden_states, + value_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask) + self_attention_output = self_attention_outputs[0] + attention_outputs = self.attention(query_tensor=self_attention_output, + key_tensor=encoder_outputs, + value_tensor=encoder_outputs, + attention_mask=attention_mask, + head_mask=head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + attention_outputs[1:] + return outputs class BertEncoder(nn.Module):