From 60ba48205e4ec070780e8cb8d461421b77432bad Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 18 Feb 2022 18:20:24 +0100 Subject: [PATCH] fix bug in PT speech-encoder-decoder (#15699) * fix bug in PT speech-encoder-decoder * add pt test for `inputs is not None` * fix test * new pt test * Update tests/test_modeling_speech_encoder_decoder.py * make fixup Co-authored-by: Patrick von Platen --- .../modeling_speech_encoder_decoder.py | 19 +++++---- tests/test_modeling_speech_encoder_decoder.py | 41 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 3140aeed6d..0d6ecedac0 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -490,15 +490,16 @@ class SpeechEncoderDecoderModel(PreTrainedModel): argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } - if encoder_outputs is None and inputs is None: - if input_values is not None and input_features is not None: - raise ValueError("You cannot specify both input_values and input_features at the same time") - elif input_values is not None: - inputs = input_values - elif input_features is not None: - inputs = input_features - else: - raise ValueError("You have to specify either input_values or input_features") + if encoder_outputs is None: + if inputs is None: + if input_values is not None and input_features is not None: + raise ValueError("You cannot specify both input_values and input_features at the same time") + elif input_values is not None: + inputs = input_values + elif input_features is not None: + inputs = input_features + else: + raise ValueError("You have to specify either input_values or input_features") encoder_outputs = self.encoder( inputs, diff --git a/tests/test_modeling_speech_encoder_decoder.py b/tests/test_modeling_speech_encoder_decoder.py index 6a5f1b589c..4cdd46d878 100644 --- a/tests/test_modeling_speech_encoder_decoder.py +++ b/tests/test_modeling_speech_encoder_decoder.py @@ -125,6 +125,43 @@ class EncoderDecoderMixin: outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) ) + def check_encoder_decoder_model_with_inputs( + self, + config, + attention_mask, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + input_values=None, + input_features=None, + **kwargs + ): + inputs = input_values if input_features is None else input_features + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + + outputs_encoder_decoder = enc_dec_model( + inputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_hidden_states=True, + ) + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + outputs_encoder_decoder_kwarg = enc_dec_model( + inputs=inputs, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_hidden_states=True, + ) + self.assertEqual( + outputs_encoder_decoder_kwarg["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + def check_encoder_decoder_model_from_pretrained( self, config, @@ -325,6 +362,10 @@ class EncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model(**input_ids_dict) + def test_encoder_decoder_model_with_inputs(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_with_inputs(**input_ids_dict) + def test_encoder_decoder_model_from_pretrained_configs(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)