Enrich TTS pipeline parameters naming (#26473)

* enrich TTS pipeline docstring for clearer forward_params use

* change token leghts

* update Pipeline parameters

* correct docstring and make style

* fix tests

* make style

* change music prompt

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* raise errors if generate_kwargs with forward-only models

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-11-02 17:06:56 +00:00
committed by GitHub
parent 147e8ce4ae
commit 0ed6729bb1
2 changed files with 110 additions and 7 deletions

View File

@@ -30,6 +30,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.trainer_utils import set_seed
from .test_pipelines_common import ANY
@@ -174,6 +175,60 @@ class TextToAudioPipelineTests(unittest.TestCase):
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
@slow
@require_torch
def test_forward_model_kwargs(self):
# use vits - a forward model
speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")
# for reproducibility
set_seed(555)
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5})
audio = outputs["audio"]
with self.assertRaises(TypeError):
# assert error if generate parameter
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True})
forward_params = {"speaker_id": 5}
generate_kwargs = {"do_sample": True}
with self.assertRaises(ValueError):
# assert error if generate_kwargs with forward-only models
outputs = speech_generator(
"This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs
)
self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5)
@slow
@require_torch
def test_generative_model_kwargs(self):
# use musicgen - a generative model
music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
forward_params = {
"do_sample": True,
"max_new_tokens": 250,
}
# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params)
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)
# make sure generate kwargs get priority over forward params
forward_params = {
"do_sample": False,
"max_new_tokens": 250,
}
generate_kwargs = {"do_sample": True}
# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
def get_test_pipeline(self, model, tokenizer, processor):
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
return speech_generator, ["This is a test", "Another test"]