🚨 🚨 Raise error when no speaker embeddings in speecht5._generate_speech (#26418)
* Add a warning when no speaker embeddings are passed to speecht5._generate_speech
* Change the warning to an error
* Adapt the generation test
This commit is contained in:
@@ -2550,6 +2550,14 @@ def _generate_speech(
|
|||||||
vocoder: Optional[nn.Module] = None,
|
vocoder: Optional[nn.Module] = None,
|
||||||
output_cross_attentions: bool = False,
|
output_cross_attentions: bool = False,
|
||||||
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
|
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
|
||||||
|
if speaker_embeddings is None:
|
||||||
|
raise ValueError(
|
||||||
|
"""`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
|
||||||
|
the code snippet provided in this link:
|
||||||
|
https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
encoder_attention_mask = torch.ones_like(input_values)
|
encoder_attention_mask = torch.ones_like(input_values)
|
||||||
|
|
||||||
encoder_out = model.speecht5.encoder(
|
encoder_out = model.speecht5.encoder(
|
||||||
|
|||||||
@@ -1015,15 +1015,21 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
set_seed(555) # make deterministic
|
set_seed(555) # make deterministic
|
||||||
|
|
||||||
|
speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
|
||||||
|
|
||||||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
generated_speech = model.generate_speech(input_ids)
|
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
|
||||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
|
||||||
|
|
||||||
|
set_seed(555) # make deterministic
|
||||||
|
|
||||||
# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
|
# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
|
||||||
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
|
generated_speech_with_generate = model.generate(
|
||||||
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
|
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
|
||||||
|
)
|
||||||
|
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user