Fix SpeechT5 decoder_attention_mask shape (#28071)

* Fix SpeechT5

* add test foward with labels and attention mask

* make style
This commit is contained in:
Yoach Lacombe
2024-06-14 15:20:11 +02:00
committed by GitHub
parent d9daeff297
commit 7e1c7dc8b6
2 changed files with 46 additions and 4 deletions

View File

@@ -909,6 +909,23 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_model_forward_with_labels(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
model = SpeechT5ForTextToSpeech(config=config).to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
decoder_attention_mask = inputs_dict["decoder_attention_mask"]
labels = inputs_dict["decoder_input_values"]
result = model(
input_ids, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask
)
self.assertEqual(
result.spectrogram.shape,
(self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.num_mel_bins),
)
# skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
def test_decoder_model_past_with_large_inputs(self):
pass
@@ -1436,6 +1453,23 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_model_forward_with_labels(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
model = SpeechT5ForSpeechToSpeech(config=config).to(torch_device).eval()
input_values = inputs_dict["input_values"]
attention_mask = inputs_dict["attention_mask"]
decoder_attention_mask = inputs_dict["decoder_attention_mask"]
labels = inputs_dict["decoder_input_values"]
result = model(
input_values, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask
)
self.assertEqual(
result.spectrogram.shape,
(self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.num_mel_bins),
)
# skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
def test_decoder_model_past_with_large_inputs(self):
pass