From bfbe68f0352a85c0dfff49c5fb0e8296f698f46e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 14 Oct 2019 12:04:23 +0200 Subject: [PATCH] update forward pass --- transformers/modeling_seq2seq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 50891ddded..e8106f47f5 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -218,12 +218,14 @@ class PreTrainedSeq2seq(nn.Module): if encoder_hidden_states is None: encoder_outputs = self.encoder(*inputs, *kwargs) encoder_hidden_states = encoder_outputs[0] + else: + encoder_outputs = (,) # Decode decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states decoder_outputs = self.decoder(**decoder_kwargs) - return decoder_outputs + return decoder_outputs + encoder_outputs class Model2Model(PreTrainedSeq2seq):