[MusicGen] Add sampling rate to config (#26136)
* [MusicGen] Add sampling rate to config * remove tiny * make property * Update tests/pipelines/test_pipelines_text_to_audio.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -226,3 +226,8 @@ class MusicgenConfig(PretrainedConfig):
|
|||||||
decoder=decoder_config.to_dict(),
|
decoder=decoder_config.to_dict(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
# This is a property because you might want to change the codec model on the fly
|
||||||
|
def sampling_rate(self):
|
||||||
|
return self.audio_encoder.sampling_rate
|
||||||
|
|||||||
@@ -41,35 +41,32 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt(self):
|
def test_small_musicgen_pt(self):
|
||||||
speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
||||||
|
|
||||||
forward_params = {
|
forward_params = {
|
||||||
"do_sample": False,
|
"do_sample": False,
|
||||||
"max_new_tokens": 250,
|
"max_new_tokens": 250,
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
outputs = music_generator("This is a test", forward_params=forward_params)
|
||||||
# musicgen sampling_rate is not straightforward to get
|
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs)
|
||||||
self.assertIsNone(outputs["sampling_rate"])
|
|
||||||
|
|
||||||
audio = outputs["audio"]
|
|
||||||
self.assertEqual(ANY(np.ndarray), audio)
|
|
||||||
|
|
||||||
# test two examples side-by-side
|
# test two examples side-by-side
|
||||||
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
outputs = music_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
# test batching
|
# test batching
|
||||||
outputs = speech_generator(
|
outputs = music_generator(
|
||||||
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
||||||
)
|
)
|
||||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
audio = [output["audio"] for output in outputs]
|
||||||
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_large_model_pt(self):
|
def test_small_bark_pt(self):
|
||||||
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
||||||
|
|
||||||
forward_params = {
|
forward_params = {
|
||||||
|
|||||||
Reference in New Issue
Block a user