the decoder attends to the output of the encoder stack (last layer)
This commit is contained in:
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
|
|||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
self.pruned_heads = self.pruned_heads.union(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
|
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
@@ -334,13 +334,13 @@ class BertLayer(nn.Module):
|
|||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
if self.is_decoder and encoder_hidden_state is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
|
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
@@ -364,8 +364,7 @@ class BertEncoder(nn.Module):
|
|||||||
if self.output_hidden_states:
|
if self.output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
encoder_hidden_state = encoder_hidden_states[i]
|
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
|
||||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
|
||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user