From db611aabee863cc5b1fdc22dcec5ce8e6c3e3b36 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:59:35 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20=F0=9F=9A=A8=20=20Raise=20error?= =?UTF-8?q?=20when=20no=20speaker=20embeddings=20in=20speecht5.=5Fgenerate?= =?UTF-8?q?=5Fspeech=20(#26418)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add warning when no speaker embeddings in speecht5._generate_speech * modify warning to error * adapt generation test --- .../models/speecht5/modeling_speecht5.py | 8 ++++++++ tests/models/speecht5/test_modeling_speecht5.py | 14 ++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c4de7de090..9b8ab3d380 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2550,6 +2550,14 @@ def _generate_speech( vocoder: Optional[nn.Module] = None, output_cross_attentions: bool = False, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: + if speaker_embeddings is None: + raise ValueError( + """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following + the code snippet provided in this link: + https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors + """ + ) + encoder_attention_mask = torch.ones_like(input_values) encoder_out = model.speecht5.encoder( diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 784461eb9a..fed01a9444 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1015,15 +1015,21 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): set_seed(555) # make deterministic + speaker_embeddings = torch.zeros((1, 512)).to(torch_device) + input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device) - generated_speech = model.generate_speech(input_ids) - self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins)) + generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings) + self.assertEqual(generated_speech.shape, (228, model.config.num_mel_bins)) + + set_seed(555) # make deterministic # test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask - generated_speech_with_generate = model.generate(input_ids, attention_mask=None) - self.assertEqual(generated_speech_with_generate.shape, (1820, model.config.num_mel_bins)) + generated_speech_with_generate = model.generate( + input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings + ) + self.assertEqual(generated_speech_with_generate.shape, (228, model.config.num_mel_bins)) @require_torch