diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 7bad5f98d3..972b80db7b 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -494,6 +494,8 @@ class EncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index db5037eb53..1453cf9370 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -514,6 +514,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 999ba2d2db..37072270a5 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -466,6 +466,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git a/tests/encoder_decoder/test_modeling_encoder_decoder.py b/tests/encoder_decoder/test_modeling_encoder_decoder.py index 7e1d3b0c97..46a1bf7b68 100644 --- a/tests/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_encoder_decoder.py @@ -142,6 +142,22 @@ class EncoderDecoderMixin: outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) + # Test passing encoder_outputs as tuple. + encoder_outputs = (encoder_hidden_states,) + outputs_encoder_decoder = enc_dec_model( + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + def check_encoder_decoder_model_from_pretrained_using_model_paths( self, config,