[VITS] Add to TTA pipeline (#25906)
* [VITS] Add to TTA pipeline * Update tests/pipelines/test_pipelines_text_to_audio.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * remove extra spaces --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
@@ -1036,6 +1036,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model for Text-To-Waveform mapping
|
# Model for Text-To-Waveform mapping
|
||||||
("bark", "BarkModel"),
|
("bark", "BarkModel"),
|
||||||
("musicgen", "MusicgenForConditionalGeneration"),
|
("musicgen", "MusicgenForConditionalGeneration"),
|
||||||
|
("vits", "VitsModel"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -56,8 +56,6 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
if self.framework == "tf":
|
if self.framework == "tf":
|
||||||
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
|
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
|
||||||
|
|
||||||
self.forward_method = self.model.generate if self.model.can_generate() else self.model
|
|
||||||
|
|
||||||
self.vocoder = None
|
self.vocoder = None
|
||||||
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
|
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
|
||||||
self.vocoder = (
|
self.vocoder = (
|
||||||
@@ -110,8 +108,10 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||||
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
||||||
|
|
||||||
# call the generate by defaults or the forward method if the model cannot generate
|
if self.model.can_generate():
|
||||||
output = self.forward_method(**model_inputs, **kwargs)
|
output = self.model.generate(**model_inputs, **kwargs)
|
||||||
|
else:
|
||||||
|
output = self.model(**model_inputs, **kwargs)[0]
|
||||||
|
|
||||||
if self.vocoder is not None:
|
if self.vocoder is not None:
|
||||||
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from .test_pipelines_common import ANY
|
|||||||
@require_torch_or_tf
|
@require_torch_or_tf
|
||||||
class TextToAudioPipelineTests(unittest.TestCase):
|
class TextToAudioPipelineTests(unittest.TestCase):
|
||||||
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
||||||
# for now only text_to_waveform and not text_to_spectrogram
|
# for now only test text_to_waveform and not text_to_spectrogram
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -50,26 +50,21 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||||
|
|
||||||
# musicgen sampling_rate is not straightforward to get
|
# musicgen sampling_rate is not straightforward to get
|
||||||
self.assertIsNone(outputs["sampling_rate"])
|
self.assertIsNone(outputs["sampling_rate"])
|
||||||
|
|
||||||
audio = outputs["audio"]
|
audio = outputs["audio"]
|
||||||
|
|
||||||
self.assertEqual(ANY(np.ndarray), audio)
|
self.assertEqual(ANY(np.ndarray), audio)
|
||||||
|
|
||||||
# test two examples side-by-side
|
# test two examples side-by-side
|
||||||
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
||||||
|
|
||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
|
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
# test batching
|
# test batching
|
||||||
outputs = speech_generator(
|
outputs = speech_generator(
|
||||||
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -77,8 +72,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
def test_large_model_pt(self):
|
def test_large_model_pt(self):
|
||||||
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
||||||
|
|
||||||
# test text-to-speech
|
|
||||||
|
|
||||||
forward_params = {
|
forward_params = {
|
||||||
# Using `do_sample=False` to force deterministic output
|
# Using `do_sample=False` to force deterministic output
|
||||||
"do_sample": False,
|
"do_sample": False,
|
||||||
@@ -86,7 +79,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||||
outputs,
|
outputs,
|
||||||
@@ -97,13 +89,10 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
["This is a test", "This is a second test"],
|
["This is a test", "This is a second test"],
|
||||||
forward_params=forward_params,
|
forward_params=forward_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
|
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
# test other generation strategy
|
# test other generation strategy
|
||||||
|
|
||||||
forward_params = {
|
forward_params = {
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"semantic_max_new_tokens": 100,
|
"semantic_max_new_tokens": 100,
|
||||||
@@ -111,9 +100,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||||
|
|
||||||
audio = outputs["audio"]
|
audio = outputs["audio"]
|
||||||
|
|
||||||
self.assertEqual(ANY(np.ndarray), audio)
|
self.assertEqual(ANY(np.ndarray), audio)
|
||||||
|
|
||||||
# test using a speaker embedding
|
# test using a speaker embedding
|
||||||
@@ -127,9 +114,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
forward_params=forward_params,
|
forward_params=forward_params,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
|
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -151,7 +136,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
"return_token_type_ids": False,
|
"return_token_type_ids": False,
|
||||||
"padding": "max_length",
|
"padding": "max_length",
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs = speech_generator(
|
outputs = speech_generator(
|
||||||
"This is a test",
|
"This is a test",
|
||||||
forward_params=forward_params,
|
forward_params=forward_params,
|
||||||
@@ -163,28 +147,44 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
forward_params["history_prompt"] = history_prompt
|
forward_params["history_prompt"] = history_prompt
|
||||||
|
|
||||||
# history_prompt is a torch.Tensor passed as a forward_param
|
# history_prompt is a torch.Tensor passed as a forward_param
|
||||||
# if generation is successfull, it means that it was passed to the right device
|
# if generation is successful, it means that it was passed to the right device
|
||||||
outputs = speech_generator(
|
outputs = speech_generator(
|
||||||
"This is a test", forward_params=forward_params, preprocess_params=preprocess_params
|
"This is a test", forward_params=forward_params, preprocess_params=preprocess_params
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||||
outputs,
|
outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_vits_model_pt(self):
|
||||||
|
speech_generator = pipeline(task="text-to-audio", model="facebook/mms-tts-eng", framework="pt")
|
||||||
|
|
||||||
|
outputs = speech_generator("This is a test")
|
||||||
|
self.assertEqual(outputs["sampling_rate"], 16000)
|
||||||
|
|
||||||
|
audio = outputs["audio"]
|
||||||
|
self.assertEqual(ANY(np.ndarray), audio)
|
||||||
|
|
||||||
|
# test two examples side-by-side
|
||||||
|
outputs = speech_generator(["This is a test", "This is a second test"])
|
||||||
|
audio = [output["audio"] for output in outputs]
|
||||||
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
|
# test batching
|
||||||
|
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
||||||
|
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||||
|
|
||||||
def get_test_pipeline(self, model, tokenizer, processor):
|
def get_test_pipeline(self, model, tokenizer, processor):
|
||||||
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
||||||
return speech_generator, ["This is a test", "Another test"]
|
return speech_generator, ["This is a test", "Another test"]
|
||||||
|
|
||||||
def run_pipeline_test(self, speech_generator, _):
|
def run_pipeline_test(self, speech_generator, _):
|
||||||
outputs = speech_generator("This is a test")
|
outputs = speech_generator("This is a test")
|
||||||
|
|
||||||
self.assertEqual(ANY(np.ndarray), outputs["audio"])
|
self.assertEqual(ANY(np.ndarray), outputs["audio"])
|
||||||
|
|
||||||
forward_params = {"num_return_sequences": 2, "do_sample": True}
|
forward_params = {"num_return_sequences": 2, "do_sample": True}
|
||||||
|
|
||||||
outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
|
outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
|
||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
|
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|||||||
Reference in New Issue
Block a user