From 4309abedbc3b767303a07133e00af95876c4ad0b Mon Sep 17 00:00:00 2001 From: Sihan Chen <39623753+Spycsh@users.noreply.github.com> Date: Tue, 14 Nov 2023 17:54:09 +0800 Subject: [PATCH] Add speecht5 batch generation and fix wrong attention mask when padding (#25943) * fix speecht5 wrong attention mask when padding * enable batch generation and add parameter attention_mask * fix doc * fix format * batch postnet inputs, return batched lengths, and consistent to old api * fix format * fix format * fix the format * fix doc-builder error * add test, cross attention and docstring * optimize code based on reviews * docbuild * refine * not skip slow test * add consistent dropout for batching * loose atol * add another test regarding to the consistency of vocoder * fix format * refactor * add return_concrete_lengths as parameter for consistency w/wo batching * fix review issues * fix cross_attention issue --- .../models/speecht5/modeling_speecht5.py | 227 ++++++++++++++---- .../models/speecht5/test_modeling_speecht5.py | 84 ++++++- 2 files changed, 257 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 25d2b73c18..63085bc046 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -674,6 +674,11 @@ class SpeechT5SpeechDecoderPrenet(nn.Module): self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size) + def _consistent_dropout(self, inputs_embeds, p): + mask = torch.bernoulli(inputs_embeds[0], p=p) + all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1) + return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p) + def forward( self, input_values: torch.Tensor, @@ -684,9 +689,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module): inputs_embeds = input_values for layer in self.layers: inputs_embeds = nn.functional.relu(layer(inputs_embeds)) - inputs_embeds = nn.functional.dropout( - inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True - ) + inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout) inputs_embeds = self.final_layer(inputs_embeds) inputs_embeds = self.encode_positions(inputs_embeds) @@ -695,6 +698,7 @@ class SpeechT5SpeechDecoderPrenet(nn.Module): speaker_embeddings = nn.functional.normalize(speaker_embeddings) speaker_embeddings = speaker_embeddings.unsqueeze(1) speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1) + speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1) inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1) inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) @@ -2461,11 +2465,13 @@ def _generate_speech( model: SpeechT5PreTrainedModel, input_values: torch.FloatTensor, speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, + return_output_lengths: bool = False, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: if speaker_embeddings is None: raise ValueError( @@ -2475,7 +2481,12 @@ def _generate_speech( """ ) - encoder_attention_mask = torch.ones_like(input_values) + if attention_mask is None: + encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int() + else: + encoder_attention_mask = attention_mask + + bsz = input_values.size(0) encoder_out = model.speecht5.encoder( input_values=input_values, @@ -2495,19 +2506,19 @@ def _generate_speech( minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor) # Start the output sequence with a mel spectrum that is all zeros. - output_sequence = encoder_last_hidden_state.new_zeros(1, 1, model.config.num_mel_bins) + output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins) spectrogram = [] cross_attentions = [] past_key_values = None idx = 0 + result_spectrogram = {} while True: idx += 1 # Run the decoder prenet on the entire output sequence. decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) - # Run the decoder layers on the last element of the prenet output. decoder_out = model.speecht5.decoder.wrapped_decoder( hidden_states=decoder_hidden_states[:, -1:], @@ -2523,36 +2534,73 @@ def _generate_speech( if output_cross_attentions: cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) - last_decoder_output = decoder_out.last_hidden_state[0, -1] + last_decoder_output = decoder_out.last_hidden_state.squeeze(1) past_key_values = decoder_out.past_key_values # Predict the new mel spectrum for this step in the sequence. spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) - spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins) + spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) spectrogram.append(spectrum) # Extend the output sequence with the new mel spectrum. - output_sequence = torch.cat((output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1) - + new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins) + output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1) # Predict the probability that this is the stop token. prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) - # Finished when stop token or maximum length is reached. - if idx >= minlen and (int(sum(prob >= threshold)) > 0 or idx >= maxlen): - spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0) - spectrogram = model.speech_decoder_postnet.postnet(spectrogram) - spectrogram = spectrogram.squeeze(0) - break - - if vocoder is not None: - outputs = vocoder(spectrogram) + if idx < minlen: + continue + else: + # If the generation loop is less than maximum length time, check the ones in the batch that have met + # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch. + if idx < maxlen: + meet_thresholds = torch.sum(prob, dim=-1) >= threshold + meet_indexes = torch.where(meet_thresholds)[0].tolist() + else: + meet_indexes = range(len(prob)) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if len(meet_indexes) > 0: + spectrograms = torch.stack(spectrogram) + spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for meet_index in meet_indexes: + result_spectrogram[meet_index] = spectrograms[meet_index] + if len(result_spectrogram) >= bsz: + break + spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: + spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + if vocoder is not None: + outputs = vocoder(spectrogram) + else: + outputs = spectrogram + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + if bsz > 1: + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (outputs, cross_attentions) else: - outputs = spectrogram - - if output_cross_attentions: - cross_attentions = torch.cat(cross_attentions, dim=2) - outputs = (outputs, cross_attentions) - + # batched return values should also include the spectrogram/waveform lengths + spectrogram_lengths = [] + for i in range(bsz): + spectrogram_lengths.append(spectrograms[i].size(0)) + if vocoder is None: + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + outputs = (spectrograms, spectrogram_lengths) + else: + waveforms = [] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + waveforms = vocoder(spectrograms) + waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + outputs = (waveforms, waveform_lengths) + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (*outputs, cross_attentions) return outputs @@ -2612,7 +2660,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently. + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and [`~PreTrainedTokenizer.__call__`] for details. @@ -2719,12 +2767,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): def generate( self, input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, speaker_embeddings: Optional[torch.FloatTensor] = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, + return_output_lengths: bool = False, **kwargs, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: r""" @@ -2733,12 +2783,15 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently. + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and [`~PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Attention mask from the tokenizer, required for batched inference to signal to the model where to + ignore padded tokens from the input_ids. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): Tensor containing the speaker embeddings. threshold (`float`, *optional*, defaults to 0.5): @@ -2752,26 +2805,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): spectrogram. output_cross_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. Returns: `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: - - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape - `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. - - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape - `(num_frames,)` -- The predicted speech waveform. - - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` - of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, - input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. """ return _generate_speech( self, input_ids, speaker_embeddings, + attention_mask, threshold, minlenratio, maxlenratio, vocoder, output_cross_attentions, + return_output_lengths, ) @torch.no_grad() @@ -2779,11 +2850,13 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): self, input_ids: torch.LongTensor, speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, + return_output_lengths: bool = False, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: r""" Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a @@ -2791,7 +2864,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently. + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and [`~PreTrainedTokenizer.__call__`] for details. @@ -2799,6 +2872,14 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): [What are input IDs?](../glossary#input-ids) speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) threshold (`float`, *optional*, defaults to 0.5): The generated sequence ends when the predicted stop token probability exceeds this value. minlenratio (`float`, *optional*, defaults to 0.0): @@ -2810,26 +2891,44 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): spectrogram. output_cross_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. Returns: `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: - - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape - `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. - - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape - `(num_frames,)` -- The predicted speech waveform. - - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` - of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, - input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. """ return _generate_speech( self, input_ids, speaker_embeddings, + attention_mask, threshold, minlenratio, maxlenratio, vocoder, output_cross_attentions, + return_output_lengths, ) @@ -2988,11 +3087,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): self, input_values: torch.FloatTensor, speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 20.0, vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, + return_output_lengths: bool = False, ) -> torch.FloatTensor: r""" Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a @@ -3000,7 +3101,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): Args: input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Float values of input raw speech waveform. The `batch_size` should be 1 currently. + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install soundfile*). To prepare the array @@ -3008,6 +3109,14 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) threshold (`float`, *optional*, defaults to 0.5): The generated sequence ends when the predicted stop token probability exceeds this value. minlenratio (`float`, *optional*, defaults to 0.0): @@ -3019,16 +3128,32 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): spectrogram. output_cross_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. Returns: `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: - - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape - `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. - - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape - `(num_frames,)` -- The predicted speech waveform. - - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` - of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, - input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. """ if speaker_embeddings is None: speaker_embeddings = torch.zeros((1, 512), device=input_values.device) @@ -3037,11 +3162,13 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): self, input_values, speaker_embeddings, + attention_mask, threshold, minlenratio, maxlenratio, vocoder, output_cross_attentions, + return_output_lengths, ) diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 65c1a340ad..c6b4b24873 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1026,14 +1026,21 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase): @require_torch @require_sentencepiece @require_tokenizers -@slow class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): + @cached_property + def default_model(self): + return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") + @cached_property def default_processor(self): return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") + @cached_property + def default_vocoder(self): + return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + def test_generation(self): - model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") + model = self.default_model model.to(torch_device) processor = self.default_processor @@ -1045,7 +1052,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device) generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings) - self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins)) + self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins)) set_seed(555) # make deterministic @@ -1053,7 +1060,76 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): generated_speech_with_generate = model.generate( input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings ) - self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins)) + self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins)) + + def test_batch_generation(self): + model = self.default_model + model.to(torch_device) + processor = self.default_processor + vocoder = self.default_vocoder + set_seed(555) # make deterministic + + input_text = [ + "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel", + "nor is mister quilter's manner less interesting than his matter", + "he tells us that at this festive season of the year with christmas and rosebeaf looming before us", + ] + inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device) + + speaker_embeddings = torch.zeros((1, 512), device=torch_device) + spectrograms, spectrogram_lengths = model.generate_speech( + input_ids=inputs["input_ids"], + speaker_embeddings=speaker_embeddings, + attention_mask=inputs["attention_mask"], + return_output_lengths=True, + ) + self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins)) + waveforms = vocoder(spectrograms) + waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + + # Check waveform results are the same with or without using vocder + set_seed(555) + waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech( + input_ids=inputs["input_ids"], + speaker_embeddings=speaker_embeddings, + attention_mask=inputs["attention_mask"], + vocoder=vocoder, + return_output_lengths=True, + ) + self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8)) + self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder) + + # Check waveform results are the same with return_concrete_lengths=True/False + set_seed(555) + waveforms_with_vocoder_no_lengths = model.generate_speech( + input_ids=inputs["input_ids"], + speaker_embeddings=speaker_embeddings, + attention_mask=inputs["attention_mask"], + vocoder=vocoder, + return_output_lengths=False, + ) + self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8)) + + # Check results when batching are consistent with results without batching + for i, text in enumerate(input_text): + set_seed(555) # make deterministic + inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device) + spectrogram = model.generate_speech( + input_ids=inputs["input_ids"], + speaker_embeddings=speaker_embeddings, + ) + self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape) + self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3)) + waveform = vocoder(spectrogram) + self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape) + # Check whether waveforms are the same with/without passing vocoder + set_seed(555) + waveform_with_vocoder = model.generate_speech( + input_ids=inputs["input_ids"], + speaker_embeddings=speaker_embeddings, + vocoder=vocoder, + ) + self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8)) @require_torch