[MusicGen] Add sampling rate to config (#26136)

* [MusicGen] Add sampling rate to config

* remove tiny

* make property

* Update tests/pipelines/test_pipelines_text_to_audio.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi
2023-09-14 16:57:06 +01:00
committed by GitHub
parent 8881f38a4f
commit 44a0490d3c
2 changed files with 14 additions and 12 deletions

View File

@@ -226,3 +226,8 @@ class MusicgenConfig(PretrainedConfig):
decoder=decoder_config.to_dict(), decoder=decoder_config.to_dict(),
**kwargs, **kwargs,
) )
@property
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate

View File

@@ -41,35 +41,32 @@ class TextToAudioPipelineTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_small_model_pt(self): def test_small_musicgen_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
forward_params = { forward_params = {
"do_sample": False, "do_sample": False,
"max_new_tokens": 250, "max_new_tokens": 250,
} }
outputs = speech_generator("This is a test", forward_params=forward_params) outputs = music_generator("This is a test", forward_params=forward_params)
# musicgen sampling_rate is not straightforward to get self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs)
self.assertIsNone(outputs["sampling_rate"])
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)
# test two examples side-by-side # test two examples side-by-side
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params) outputs = music_generator(["This is a test", "This is a second test"], forward_params=forward_params)
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
# test batching # test batching
outputs = speech_generator( outputs = music_generator(
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
) )
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow @slow
@require_torch @require_torch
def test_large_model_pt(self): def test_small_bark_pt(self):
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt") speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
forward_params = { forward_params = {