From f873a3edb2eba78d92e55ecb55902f7e9cb90777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 15:21:46 +0200 Subject: [PATCH] the decoder attends to the output of the encoder stack (last layer) --- transformers/modeling_bert.py | 13 ++++++------- transformers/modeling_seq2seq.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index be8ec5ba21..aa022bac8a 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -288,8 +288,8 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): - self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask) + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): + self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs @@ -334,13 +334,13 @@ class BertLayer(nn.Module): self.intermediate = BertIntermediate(config) self.output = BertOutput(config) - def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None): + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None): self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - if self.is_decoder and encoder_hidden_state is not None: - cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask) + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights @@ -364,8 +364,7 @@ class BertEncoder(nn.Module): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - encoder_hidden_state = encoder_hidden_states[i] - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask) hidden_states = layer_outputs[0] if self.output_attentions: diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index ca3b9dc87a..108fdaa853 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module): encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] + encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack* else: encoder_outputs = ()