From 51e0ebedcb183d3cd3738bba6765a261c6381552 Mon Sep 17 00:00:00 2001 From: jsnfly <37632631+jsnfly@users.noreply.github.com> Date: Mon, 18 Apr 2022 19:49:58 +0200 Subject: [PATCH] Allow passing encoder_outputs as tuple to EncoderDecoder Models (#16814) * Add passing encoder_outputs as tuple to existing test * Add check for tuple * Add check for tuple also for speech and vision Co-authored-by: jsnfly --- .../encoder_decoder/modeling_encoder_decoder.py | 4 +++- .../modeling_speech_encoder_decoder.py | 4 +++- .../modeling_vision_encoder_decoder.py | 4 +++- .../test_modeling_encoder_decoder.py | 16 ++++++++++++++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 7bad5f98d3..972b80db7b 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -494,6 +494,8 @@ class EncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index db5037eb53..1453cf9370 100644 --- 
a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -514,6 +514,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 999ba2d2db..37072270a5 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import Seq2SeqLMOutput +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto.configuration_auto import AutoConfig @@ -466,6 +466,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): return_dict=return_dict, **kwargs_encoder, ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) encoder_hidden_states = encoder_outputs[0] diff --git 
a/tests/encoder_decoder/test_modeling_encoder_decoder.py b/tests/encoder_decoder/test_modeling_encoder_decoder.py index 7e1d3b0c97..46a1bf7b68 100644 --- a/tests/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_encoder_decoder.py @@ -142,6 +142,22 @@ class EncoderDecoderMixin: outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) + # Test passing encoder_outputs as tuple. + encoder_outputs = (encoder_hidden_states,) + outputs_encoder_decoder = enc_dec_model( + encoder_outputs=encoder_outputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + def check_encoder_decoder_model_from_pretrained_using_model_paths( self, config,