Add speecht5 batch generation and fix wrong attention mask when padding (#25943)
* fix speecht5 wrong attention mask when padding * enable batch generation and add parameter attention_mask * fix doc * fix format * batch postnet inputs, return batched lengths, and consistent to old api * fix format * fix format * fix the format * fix doc-builder error * add test, cross attention and docstring * optimize code based on reviews * docbuild * refine * not skip slow test * add consistent dropout for batching * loose atol * add another test regarding to the consistency of vocoder * fix format * refactor * add return_concrete_lengths as parameter for consistency w/wo batching * fix review issues * fix cross_attention issue
This commit is contained in:
@@ -1026,14 +1026,21 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_model(self):
|
||||
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
|
||||
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
|
||||
|
||||
@cached_property
|
||||
def default_vocoder(self):
|
||||
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
|
||||
|
||||
def test_generation(self):
|
||||
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
|
||||
model = self.default_model
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
|
||||
@@ -1045,7 +1052,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
|
||||
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
|
||||
self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
@@ -1053,7 +1060,76 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
generated_speech_with_generate = model.generate(
|
||||
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
|
||||
self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
|
||||
|
||||
def test_batch_generation(self):
|
||||
model = self.default_model
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
vocoder = self.default_vocoder
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
input_text = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||
"nor is mister quilter's manner less interesting than his matter",
|
||||
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
|
||||
]
|
||||
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
|
||||
|
||||
speaker_embeddings = torch.zeros((1, 512), device=torch_device)
|
||||
spectrograms, spectrogram_lengths = model.generate_speech(
|
||||
input_ids=inputs["input_ids"],
|
||||
speaker_embeddings=speaker_embeddings,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
return_output_lengths=True,
|
||||
)
|
||||
self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))
|
||||
waveforms = vocoder(spectrograms)
|
||||
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
|
||||
|
||||
# Check waveform results are the same with or without using vocder
|
||||
set_seed(555)
|
||||
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
|
||||
input_ids=inputs["input_ids"],
|
||||
speaker_embeddings=speaker_embeddings,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
vocoder=vocoder,
|
||||
return_output_lengths=True,
|
||||
)
|
||||
self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
|
||||
self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)
|
||||
|
||||
# Check waveform results are the same with return_concrete_lengths=True/False
|
||||
set_seed(555)
|
||||
waveforms_with_vocoder_no_lengths = model.generate_speech(
|
||||
input_ids=inputs["input_ids"],
|
||||
speaker_embeddings=speaker_embeddings,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
vocoder=vocoder,
|
||||
return_output_lengths=False,
|
||||
)
|
||||
self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))
|
||||
|
||||
# Check results when batching are consistent with results without batching
|
||||
for i, text in enumerate(input_text):
|
||||
set_seed(555) # make deterministic
|
||||
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
|
||||
spectrogram = model.generate_speech(
|
||||
input_ids=inputs["input_ids"],
|
||||
speaker_embeddings=speaker_embeddings,
|
||||
)
|
||||
self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
|
||||
self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
|
||||
waveform = vocoder(spectrogram)
|
||||
self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
|
||||
# Check whether waveforms are the same with/without passing vocoder
|
||||
set_seed(555)
|
||||
waveform_with_vocoder = model.generate_speech(
|
||||
input_ids=inputs["input_ids"],
|
||||
speaker_embeddings=speaker_embeddings,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user