diff --git a/docs/source/en/model_doc/musicgen.md b/docs/source/en/model_doc/musicgen.md index 40c4838273..d9e8429485 100644 --- a/docs/source/en/model_doc/musicgen.md +++ b/docs/source/en/model_doc/musicgen.md @@ -57,6 +57,11 @@ Generation is limited by the sinusoidal positional embeddings to 30 second input than 30 seconds of audio (1503 tokens), and input audio passed by Audio-Prompted Generation contributes to this limit so, given an input of 20 seconds of audio, MusicGen cannot generate more than 10 seconds of additional audio. +Transformers supports both mono (1-channel) and stereo (2-channel) variants of MusicGen. The mono channel versions +generate a single set of codebooks. The stereo versions generate 2 sets of codebooks, 1 for each channel (left/right), +and each set of codebooks is decoded independently through the audio compression model. The audio streams for each +channel are combined to give the final stereo output. + ### Unconditional Generation The inputs for unconditional (or 'null') generation can be obtained through the method diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py index 03371e1044..e954181242 100644 --- a/src/transformers/models/musicgen/configuration_musicgen.py +++ b/src/transformers/models/musicgen/configuration_musicgen.py @@ -75,6 +75,9 @@ class MusicgenDecoderConfig(PretrainedConfig): The number of parallel codebooks forwarded to the model. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether input and output word embeddings should be tied. + audio_channels (`int`, *optional*, defaults to 1 + Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate + audio stream for the left/right output channels. Mono models generate a single audio stream output. """ model_type = "musicgen_decoder" keys_to_ignore_at_inference = ["past_key_values"] @@ -96,6 +99,7 @@ class MusicgenDecoderConfig(PretrainedConfig): initializer_factor=0.02, scale_embedding=False, num_codebooks=4, + audio_channels=1, pad_token_id=2048, bos_token_id=2048, eos_token_id=None, @@ -117,6 +121,11 @@ class MusicgenDecoderConfig(PretrainedConfig): self.use_cache = use_cache self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.num_codebooks = num_codebooks + + if audio_channels not in [1, 2]: + raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.") + self.audio_channels = audio_channels + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/musicgen/convert_musicgen_transformers.py b/src/transformers/models/musicgen/convert_musicgen_transformers.py index 517f0099d0..d4b61046e5 100644 --- a/src/transformers/models/musicgen/convert_musicgen_transformers.py +++ b/src/transformers/models/musicgen/convert_musicgen_transformers.py @@ -88,32 +88,48 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig: - if checkpoint == "small": + if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small": # default config values hidden_size = 1024 num_hidden_layers = 24 num_attention_heads = 16 - elif checkpoint == "medium": + elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium": hidden_size = 1536 num_hidden_layers = 48 num_attention_heads = 24 - elif checkpoint == "large": + elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large": hidden_size = 2048 num_hidden_layers = 48 num_attention_heads = 32 else: - raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.") + raise ValueError( + "Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, " + "or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` " + f"for the stereo checkpoints, got {checkpoint}." + ) + + if "stereo" in checkpoint: + audio_channels = 2 + num_codebooks = 8 + else: + audio_channels = 1 + num_codebooks = 4 + config = MusicgenDecoderConfig( hidden_size=hidden_size, ffn_dim=hidden_size * 4, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, + num_codebooks=num_codebooks, + audio_channels=audio_channels, ) return config @torch.no_grad() -def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"): +def convert_musicgen_checkpoint( + checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False +): fairseq_model = MusicGen.get_pretrained(checkpoint, device=device) decoder_config = decoder_config_from_checkpoint(checkpoint) @@ -146,18 +162,20 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict) # check we can do a forward pass - input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1) - decoder_input_ids = input_ids.reshape(2 * 4, -1) + input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1) + decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1) with torch.no_grad(): logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits - if logits.shape != (8, 1, 2048): + if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048): raise ValueError("Incorrect shape for logits") # now construct the processor tokenizer = AutoTokenizer.from_pretrained("t5-base") - feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left") + feature_extractor = AutoFeatureExtractor.from_pretrained( + "facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels + ) processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) @@ -173,12 +191,12 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No if pytorch_dump_folder is not None: Path(pytorch_dump_folder).mkdir(exist_ok=True) logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}") - model.save_pretrained(pytorch_dump_folder) + model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization) processor.save_pretrained(pytorch_dump_folder) if repo_id: logger.info(f"Pushing model {checkpoint} to {repo_id}") - model.push_to_hub(repo_id) + model.push_to_hub(repo_id, safe_serialization=safe_serialization) processor.push_to_hub(repo_id) @@ -189,7 +207,10 @@ if __name__ == "__main__": "--checkpoint", default="small", type=str, - help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.", + help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: " + "`['small', 'medium', 'large']` for the mono checkpoints, or " + "`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` " + "for the stereo checkpoints.", ) parser.add_argument( "--pytorch_dump_folder", @@ -204,6 +225,11 @@ if __name__ == "__main__": parser.add_argument( "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).", + ) args = parser.parse_args() convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 2a015fc032..584b29e623 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1077,21 +1077,33 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 ) + channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks # we only apply the mask if we have a large enough seq len - otherwise we return as is - if max_length < 2 * num_codebooks - 1: + if max_length < 2 * channel_codebooks - 1: return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) # fill the shifted ids with the prompt entries, offset by the codebook idx - for codebook in range(num_codebooks): - input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] + for codebook in range(channel_codebooks): + if self.config.audio_channels == 1: + # mono channel - loop over the codebooks one-by-one + input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] + else: + # left/right channels are interleaved in the generated codebooks, so handle one then the other + input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook] + input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1] # construct a pattern mask that indicates the positions of padding tokens for each codebook # first fill the upper triangular part (the EOS padding) delay_pattern = torch.triu( - torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1 + torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1 ) # then fill the lower triangular part (the BOS padding) - delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool)) + delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool)) + + if self.config.audio_channels == 2: + # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion + delay_pattern = delay_pattern.repeat_interleave(2, dim=0) + mask = ~delay_pattern.to(input_ids.device) input_ids = mask * input_ids_shifted + ~mask * pad_token_id @@ -1856,6 +1868,11 @@ class MusicgenForConditionalGeneration(PreTrainedModel): f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " "disabled by setting `chunk_length=None` in the audio encoder." ) + + if self.config.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2: + # mono input through encodec that we convert to stereo + audio_codes = audio_codes.repeat_interleave(2, dim=2) + decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) # Decode @@ -2074,12 +2091,42 @@ class MusicgenForConditionalGeneration(PreTrainedModel): # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name encoder_kwargs["return_dict"] = True - encoder_kwargs[model_input_name] = input_values - audio_encoder_outputs = encoder.encode(**encoder_kwargs) + if self.decoder.config.audio_channels == 1: + encoder_kwargs[model_input_name] = input_values + audio_encoder_outputs = encoder.encode(**encoder_kwargs) + audio_codes = audio_encoder_outputs.audio_codes + audio_scales = audio_encoder_outputs.audio_scales - audio_codes = audio_encoder_outputs.audio_codes - frames, bsz, codebooks, seq_len = audio_codes.shape + frames, bsz, codebooks, seq_len = audio_codes.shape + + else: + if input_values.shape[1] != 2: + raise ValueError( + f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel." + ) + + encoder_kwargs[model_input_name] = input_values[:, :1, :] + audio_encoder_outputs_left = encoder.encode(**encoder_kwargs) + audio_codes_left = audio_encoder_outputs_left.audio_codes + audio_scales_left = audio_encoder_outputs_left.audio_scales + + encoder_kwargs[model_input_name] = input_values[:, 1:, :] + audio_encoder_outputs_right = encoder.encode(**encoder_kwargs) + audio_codes_right = audio_encoder_outputs_right.audio_codes + audio_scales_right = audio_encoder_outputs_right.audio_scales + + frames, bsz, codebooks, seq_len = audio_codes_left.shape + # copy alternating left/right channel codes into stereo codebook + audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len)) + + audio_codes[:, :, ::2, :] = audio_codes_left + audio_codes[:, :, 1::2, :] = audio_codes_right + + if audio_scales_left != [None] or audio_scales_right != [None]: + audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1) + else: + audio_scales = [None] * bsz if frames != 1: raise ValueError( @@ -2090,7 +2137,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) model_kwargs["decoder_input_ids"] = decoder_input_ids - model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales + model_kwargs["audio_scales"] = audio_scales return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): @@ -2433,16 +2480,25 @@ class MusicgenForConditionalGeneration(PreTrainedModel): if audio_scales is None: audio_scales = [None] * batch_size - output_values = self.audio_encoder.decode( - output_ids, - audio_scales=audio_scales, - ) + if self.decoder.config.audio_channels == 1: + output_values = self.audio_encoder.decode( + output_ids, + audio_scales=audio_scales, + ).audio_values + else: + codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales) + output_values_left = codec_outputs_left.audio_values + + codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales) + output_values_right = codec_outputs_right.audio_values + + output_values = torch.cat([output_values_left, output_values_right], dim=1) if generation_config.return_dict_in_generate: - outputs.sequences = output_values.audio_values + outputs.sequences = output_values return outputs else: - return output_values.audio_values + return output_values def get_unconditional_inputs(self, num_samples=1): """ diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 2cd662bfe5..5e1d9ccdf2 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste self.assertIsInstance(output_sample, SampleDecoderOnlyOutput) self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.audio_channels = 2 + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + def prepare_musicgen_inputs_dict( config, @@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10 ) + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config.audio_channels = 2 + + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + decoder_input_ids=decoder_input_ids, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): """Produces a series of 'bip bip' sounds at a given frequency.""" @@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase): output_values.shape == (2, 1, 36480) ) # input values take shape 32000 and we generate from there self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4)) + + +@require_torch +class MusicgenStereoIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device) + + @cached_property + def processor(self): + return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small") + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013, + -0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019, + 0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + self.assertTrue(output_values.shape == (1, 2, 5760)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + # create stereo inputs + audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728, + -0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103, + -0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + self.assertTrue(output_values.shape == (2, 2, 37760)) + # input values take shape 32000 and we generate from there - we check the last (generated) values + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))