the decoder attends to the output of the encoder stack (last layer)

This commit is contained in:
Rémi Louf
2019-10-17 15:21:46 +02:00
parent 56e2ee4ead
commit f873a3edb2
2 changed files with 7 additions and 8 deletions

View File

@@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
else:
encoder_outputs = ()