Allow passing encoder_ouputs as tuple to EncoderDecoder Models (#16814)
* Add passing encoder_outputs as tuple to existing test * Add check for tuple * Add check for tuple also for speech and vision Co-authored-by: jsnfly <jsnfly@gmx.de>
This commit is contained in:
@@ -142,6 +142,22 @@ class EncoderDecoderMixin:
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
# Test passing encoder_outputs as tuple.
|
||||
encoder_outputs = (encoder_hidden_states,)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_using_model_paths(
|
||||
self,
|
||||
config,
|
||||
|
||||
Reference in New Issue
Block a user