The decoder attends to the output of the encoder stack (last layer).
This commit is contained in:
@@ -165,7 +165,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
- encoder_hidden_states = encoder_outputs[0]
|
||||
+ encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user