🚨 🚨 Raise error when no speaker embeddings in speecht5._generate_speech (#26418)
* add warning when no speaker embeddings in speecht5._generate_speech * modify warning to error * adapt generation test
This commit is contained in:
@@ -1015,15 +1015,21 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
|
||||
|
||||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
generated_speech = model.generate_speech(input_ids)
|
||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
||||
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
|
||||
self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins))
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
|
||||
generated_speech_with_generate = model.generate(input_ids, attention_mask=None)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins))
|
||||
generated_speech_with_generate = model.generate(
|
||||
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins))
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user