fix bug in PT speech-encoder-decoder (#15699)
* fix bug in PT speech-encoder-decoder
* add pt test for `inputs is not None`
* fix test
* new pt test
* Update tests/test_modeling_speech_encoder_decoder.py
* make fixup

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -490,15 +490,16 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
|
|
||||||
if encoder_outputs is None and inputs is None:
|
if encoder_outputs is None:
|
||||||
if input_values is not None and input_features is not None:
|
if inputs is None:
|
||||||
raise ValueError("You cannot specify both input_values and input_features at the same time")
|
if input_values is not None and input_features is not None:
|
||||||
elif input_values is not None:
|
raise ValueError("You cannot specify both input_values and input_features at the same time")
|
||||||
inputs = input_values
|
elif input_values is not None:
|
||||||
elif input_features is not None:
|
inputs = input_values
|
||||||
inputs = input_features
|
elif input_features is not None:
|
||||||
else:
|
inputs = input_features
|
||||||
raise ValueError("You have to specify either input_values or input_features")
|
else:
|
||||||
|
raise ValueError("You have to specify either input_values or input_features")
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs,
|
inputs,
|
||||||
|
|||||||
@@ -125,6 +125,43 @@ class EncoderDecoderMixin:
|
|||||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_with_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
attention_mask,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
input_values=None,
|
||||||
|
input_features=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
inputs = input_values if input_features is None else input_features
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
inputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
outputs_encoder_decoder_kwarg = enc_dec_model(
|
||||||
|
inputs=inputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder_kwarg["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_from_pretrained(
|
def check_encoder_decoder_model_from_pretrained(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -325,6 +362,10 @@ class EncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model(**input_ids_dict)
|
self.check_encoder_decoder_model(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_with_inputs(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_with_inputs(**input_ids_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user