Fix SpeechT5 decoder_attention_mask shape (#28071)
* Fix SpeechT5 * add test foward with labels and attention mask * make style
This commit is contained in:
@@ -909,6 +909,23 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_model_forward_with_labels(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
model = SpeechT5ForTextToSpeech(config=config).to(torch_device).eval()
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
labels = inputs_dict["decoder_input_values"]
|
||||
|
||||
result = model(
|
||||
input_ids, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask
|
||||
)
|
||||
self.assertEqual(
|
||||
result.spectrogram.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.num_mel_bins),
|
||||
)
|
||||
|
||||
# skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
pass
|
||||
@@ -1436,6 +1453,23 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_model_forward_with_labels(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
model = SpeechT5ForSpeechToSpeech(config=config).to(torch_device).eval()
|
||||
|
||||
input_values = inputs_dict["input_values"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
labels = inputs_dict["decoder_input_values"]
|
||||
|
||||
result = model(
|
||||
input_values, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask
|
||||
)
|
||||
self.assertEqual(
|
||||
result.spectrogram.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.num_mel_bins),
|
||||
)
|
||||
|
||||
# skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user