🚨 🚨 Raise error when no speaker embeddings in speecht5._generate_speech (#26418)
* Add a warning when no speaker embeddings are passed to speecht5._generate_speech
* Change the warning to an error
* Adapt the generation test
This commit is contained in:
@@ -2550,6 +2550,14 @@ def _generate_speech(
|
||||
vocoder: Optional[nn.Module] = None,
|
||||
output_cross_attentions: bool = False,
|
||||
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
|
||||
if speaker_embeddings is None:
|
||||
raise ValueError(
|
||||
"""`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
|
||||
the code snippet provided in this link:
|
||||
https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
|
||||
"""
|
||||
)
|
||||
|
||||
encoder_attention_mask = torch.ones_like(input_values)
|
||||
|
||||
encoder_out = model.speecht5.encoder(
|
||||
|
||||
@@ -1015,15 +1015,21 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
|
||||
|
||||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
generated_speech = model.generate_speech(input_ids)
|
||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
||||
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
|
||||
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
# test model.generate: same method as generate_speech, but with additional kwargs to absorb arguments such as attention_mask
|
||||
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
|
||||
generated_speech_with_generate = model.generate(
|
||||
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user