Allow passing encoder_ouputs as tuple to EncoderDecoder Models (#16814)
* Add passing encoder_outputs as tuple to existing test * Add check for tuple * Add check for tuple also for speech and vision Co-authored-by: jsnfly <jsnfly@gmx.de>
This commit is contained in:
@@ -22,7 +22,7 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...modeling_outputs import Seq2SeqLMOutput
|
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
@@ -494,6 +494,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
**kwargs_encoder,
|
**kwargs_encoder,
|
||||||
)
|
)
|
||||||
|
elif isinstance(encoder_outputs, tuple):
|
||||||
|
encoder_outputs = BaseModelOutput(*encoder_outputs)
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...modeling_outputs import Seq2SeqLMOutput
|
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
@@ -514,6 +514,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
**kwargs_encoder,
|
**kwargs_encoder,
|
||||||
)
|
)
|
||||||
|
elif isinstance(encoder_outputs, tuple):
|
||||||
|
encoder_outputs = BaseModelOutput(*encoder_outputs)
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...modeling_outputs import Seq2SeqLMOutput
|
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
@@ -466,6 +466,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
**kwargs_encoder,
|
**kwargs_encoder,
|
||||||
)
|
)
|
||||||
|
elif isinstance(encoder_outputs, tuple):
|
||||||
|
encoder_outputs = BaseModelOutput(*encoder_outputs)
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -142,6 +142,22 @@ class EncoderDecoderMixin:
|
|||||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test passing encoder_outputs as tuple.
|
||||||
|
encoder_outputs = (encoder_hidden_states,)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_from_pretrained_using_model_paths(
|
def check_encoder_decoder_model_from_pretrained_using_model_paths(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
|||||||
Reference in New Issue
Block a user