diff --git a/docs/source/en/model_doc/speecht5.md b/docs/source/en/model_doc/speecht5.md index 915fe2a9c4..4d5e2098a5 100644 --- a/docs/source/en/model_doc/speecht5.md +++ b/docs/source/en/model_doc/speecht5.md @@ -71,7 +71,7 @@ This model was contributed by [Matthijs](https://huggingface.co/Matthijs). The o [[autodoc]] SpeechT5ForTextToSpeech - forward - - generate_speech + - generate ## SpeechT5ForSpeechToSpeech diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 5a3f117adb..f4e0e8052c 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2717,7 +2717,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): >>> set_seed(555) # make deterministic >>> # generate speech - >>> speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder) + >>> speech = model.generate(inputs["input_ids"], speaker_embeddings, vocoder=vocoder) >>> speech.shape torch.Size([15872]) ``` @@ -2783,6 +2783,65 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): encoder_attentions=outputs.encoder_attentions, ) + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + **kwargs, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. The `batch_size` should be 1 currently. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) `torch.FloatTensor` + of shape `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, + input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + return _generate_speech( + self, + input_ids, + speaker_embeddings, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + ) + @torch.no_grad() def generate_speech( self, diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 9324996ffe..eaec854914 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1020,6 +1020,10 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): generated_speech = model.generate_speech(input_ids) self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins)) + # test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask + generated_speech_with_generate = model.generate(input_ids, attention_mask=None) + self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins)) + @require_torch class SpeechT5ForSpeechToSpeechTester: