From 44a0490d3c46e62134b3fc1f0609cbdf83e571da Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:57:06 +0100 Subject: [PATCH] [MusicGen] Add sampling rate to config (#26136) * [MusicGen] Add sampling rate to config * remove tiny * make property * Update tests/pipelines/test_pipelines_text_to_audio.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/musicgen/configuration_musicgen.py | 5 +++++ .../pipelines/test_pipelines_text_to_audio.py | 21 ++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py index 315f038110..03371e1044 100644 --- a/src/transformers/models/musicgen/configuration_musicgen.py +++ b/src/transformers/models/musicgen/configuration_musicgen.py @@ -226,3 +226,8 @@ class MusicgenConfig(PretrainedConfig): decoder=decoder_config.to_dict(), **kwargs, ) + + @property + # This is a property because you might want to change the codec model on the fly + def sampling_rate(self): + return self.audio_encoder.sampling_rate diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index c7d8aa2b64..4a42122ce6 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -41,35 +41,32 @@ class TextToAudioPipelineTests(unittest.TestCase): @slow @require_torch - def test_small_model_pt(self): - speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") + def test_small_musicgen_pt(self): + music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt") forward_params = { "do_sample": False, "max_new_tokens": 250, } - outputs = speech_generator("This is a test", forward_params=forward_params) - # musicgen sampling_rate is not straightforward to get - self.assertIsNone(outputs["sampling_rate"]) - - audio = outputs["audio"] - self.assertEqual(ANY(np.ndarray), audio) + outputs = music_generator("This is a test", forward_params=forward_params) + self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs) # test two examples side-by-side - outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params) + outputs = music_generator(["This is a test", "This is a second test"], forward_params=forward_params) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) # test batching - outputs = speech_generator( + outputs = music_generator( ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 ) - self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) @slow @require_torch - def test_large_model_pt(self): + def test_small_bark_pt(self): speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt") forward_params = {