Enrich TTS pipeline parameters naming (#26473)
* enrich TTS pipeline docstring for clearer forward_params use * change token lengths * update Pipeline parameters * correct docstring and make style * fix tests * make style * change music prompt Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * raise errors if generate_kwargs with forward-only models * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
@@ -174,6 +175,60 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||
|
||||
@slow
@require_torch
def test_forward_model_kwargs(self):
    """Check how a forward-only model (VITS) handles `forward_params` and `generate_kwargs`.

    The test verifies three things:
      * a generation-only argument placed in `forward_params` raises `TypeError`;
      * passing `generate_kwargs` at all to a forward-only model raises `ValueError`;
      * after those rejected calls, a valid seeded call still reproduces the reference audio.
    """
    # use vits - a forward model
    speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")

    forward_params = {"speaker_id": 5}
    generate_kwargs = {"do_sample": True}

    # for reproducibility
    set_seed(555)
    outputs = speech_generator("This is a test", forward_params={"speaker_id": 5})
    audio = outputs["audio"]

    with self.assertRaises(TypeError):
        # assert error if generate parameter
        speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True})

    with self.assertRaises(ValueError):
        # assert error if generate_kwargs with forward-only models
        speech_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)

    # Both calls above raise, so they never produce a new result. Re-seed and run the valid
    # call again so the comparison below is made against a freshly computed output rather
    # than against the reference itself.
    set_seed(555)
    outputs = speech_generator("This is a test", forward_params=forward_params)
    self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5)
|
||||
|
||||
@slow
@require_torch
def test_generative_model_kwargs(self):
    """Check that a generative model (MusicGen) takes generation options from both argument dicts.

    The first call passes the sampling options through `forward_params` only. The second call
    sets `do_sample=False` in `forward_params` but `do_sample=True` in `generate_kwargs`; with
    the same seed it must yield the same audio, showing that `generate_kwargs` take priority.
    """
    prompt = "This is a test"

    # use musicgen - a generative model
    music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")

    # for reproducibility
    set_seed(555)
    reference = music_generator(prompt, forward_params=dict(do_sample=True, max_new_tokens=250))
    audio = reference["audio"]
    self.assertEqual(ANY(np.ndarray), audio)

    # make sure generate kwargs get priority over forward params
    conflicting_forward_params = dict(do_sample=False, max_new_tokens=250)
    overriding_generate_kwargs = dict(do_sample=True)

    # for reproducibility
    set_seed(555)
    overridden = music_generator(
        prompt,
        forward_params=conflicting_forward_params,
        generate_kwargs=overriding_generate_kwargs,
    )
    self.assertListEqual(overridden["audio"].tolist(), audio.tolist())
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
    """Return a `TextToAudioPipeline` built from `model` and `tokenizer`, plus two sample prompts.

    `processor` is accepted as part of the signature but is not used in this method.
    """
    sample_prompts = ["This is a test", "Another test"]
    return TextToAudioPipeline(model=model, tokenizer=tokenizer), sample_prompts
|
||||
|
||||
Reference in New Issue
Block a user