diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 0951baff7d..95d56cf6ac 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -232,7 +232,7 @@ class PreTrainedEncoderDecoder(nn.Module): encoder_outputs = () kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder) + decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs