From 9ca788b2e8f02ea08796e66628b1fd176245f896 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 10 Oct 2019 11:33:28 +0200
Subject: [PATCH] merge the two Bert layers classes

---
Review notes (ignored by `git am`, not part of the commit message):
- BertLayer replaces both BertEncoderLayer and BertDecoderLayer. It always runs
  self-attention; the second (cross-) attention module is only created when the
  configuration carries a truthy `is_decoder`, and is only applied when
  `encoder_hidden_state` is passed to forward().
- The decoder flag is read with getattr() and set as an attribute, in line with
  the attribute-style access used for the other config fields in this file.
- The presence of `encoder_hidden_state` is tested with `is not None`.
- A missing cross-attention module now raises a proper AttributeError whose
  message names the `is_decoder` flag.

 transformers/modeling_bert.py | 54 +++++++++++++++--------------------
 1 file changed, 23 insertions(+), 31 deletions(-)

diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 89407ff8ab..f982364f5e 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -318,15 +318,26 @@ class BertOutput(nn.Module):
         return hidden_states
 
 
-class BertEncoderLayer(nn.Module):
+class BertLayer(nn.Module):
     def __init__(self, config):
-        super(BertEncoderLayer, self).__init__()
-        self.attention = BertAttention(config)
+        super(BertLayer, self).__init__()
+        self.self_attention = BertAttention(config)
+        if getattr(config, 'is_decoder', False):  # config fields are read as attributes
+            self.attention = BertAttention(config)  # cross-attention, decoder only
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
+        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
+        self_attention_output = self_attention_outputs[0]
+
+        attention_outputs = self_attention_outputs
+        if encoder_hidden_state is not None:  # explicit None check instead of truth-testing the argument
+            try:
+                attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
+            except AttributeError:  # `self.attention` only exists when `is_decoder` is set
+                raise AttributeError("you need to set `is_decoder` to True in the configuration to instantiate a decoder layer")
+
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
@@ -334,35 +345,12 @@ class BertEncoderLayer(nn.Module):
         return outputs
 
 
-class BertDecoderLayer(nn.Module):
-    def __init__(self, config):
-        super(BertDecoderLayer, self).__init__()
-        self.self_attention = BertAttention(config)
-        self.attention = BertDecoderAttention(config)
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-    def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
-        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
-        self_attention_output = self_attention_outputs[0]
-        attention_outputs = self.attention(query=self_attention_output,
-                                           key=encoder_outputs,
-                                           value=encoder_outputs,
-                                           attention_mask=attention_mask,
-                                           head_mask=head_mask)
-        attention_output = attention_outputs[0]
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]
-        return outputs
-
-
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
         all_hidden_states = ()
@@ -392,9 +380,10 @@ class BertEncoder(nn.Module):
 class BertDecoder(nn.Module):
     def __init__(self, config):
         super(BertDecoder, self).__init__()
+        config.is_decoder = True  # NOTE: mutates the config object passed in by the caller
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
@@ -403,7 +392,10 @@ class BertDecoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            layer_outputs = layer_module(hidden_states,
+                                         attention_mask=attention_mask,
+                                         head_mask=head_mask[i],
+                                         encoder_hidden_state=encoder_outputs)
 
             if self.output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)