From 90cb55bf773d6879441616e6378d16971b557868 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 25 Apr 2024 12:51:19 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Add=20training=20compatibility?= =?UTF-8?q?=20for=20Musicgen-like=20models=20(#29802)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first modeling code * make repository * still WIP * update model * add tests * add latest change * clean docstrings and copied from * update docstrings md and readme * correct chroma function * correct copied from and remove unreleated test * add doc to toctree * correct imports * add convert script to notdoctested * Add suggestion from Sanchit Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct get_uncoditional_inputs docstrings * modify README according to SANCHIT feedback * add chroma to audio utils * clean librosa and torchaudio hard dependencies * fix FE * refactor audio decoder -> audio encoder for consistency with previous musicgen * refactor conditional -> encoder * modify sampling rate logics * modify license at the beginning * refactor all_self_attns->all_attentions * remove ignore copy from causallm generate * add copied from for from_sub_models * fix make copies * add warning if audio is truncated * add copied from where relevant * remove artefact * fix convert script * fix torchaudio and FE * modify chroma method according to feedback-> better naming * refactor input_values->input_features * refactor input_values->input_features and fix import fe * add input_features to docstrigs * correct inputs_embeds logics * remove dtype conversion * refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation * change warning for chroma length * Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change way to save wav, using soundfile * correct docs and change to soundfile * fix import * fix init proj layers * add draft training * fix cross entropy * clean loss computation * fix labels * remove line breaks from md * fix issue with docstrings * add FE suggestions * improve is in logics and remove useless imports * remove custom from_pretrained * simplify docstring code * add suggestions for modeling tests * make style * update converting script with sanity check * remove encoder attention mask from conditional generation * replace musicgen melody checkpoints with official orga * rename ylacombe->facebook in checkpoints * fix copies * remove unecessary warning * add shape in code docstrings * add files to slow doc tests * fix md bug and add md to not_tested * make fix-copies * fix hidden states test and batching * update training code * add training tests for melody * add training for o.g musicgen * fix copied from * remove final todos * make style * fix style * add suggestions from review * add ref to the original loss computation code * rename method + fix labels in tests * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/auto/modeling_auto.py | 2 + .../models/musicgen/modeling_musicgen.py | 72 +++++++--- .../modeling_musicgen_melody.py | 69 +++++++--- .../models/musicgen/test_modeling_musicgen.py | 119 ++++++++++++++++- .../test_modeling_musicgen_melody.py | 123 +++++++++++++++++- 5 files changed, 333 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 50b2335800..f00c223d2e 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -161,6 +161,8 @@ MODEL_MAPPING_NAMES = OrderedDict( ("mpt", "MptModel"), ("mra", "MraModel"), ("mt5", "MT5Model"), + ("musicgen", "MusicgenModel"), + ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), ("nat", "NatModel"), ("nezha", "NezhaModel"), diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 7e7c7cb723..0c2f856f0e 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -104,16 +104,17 @@ class MusicgenUnconditionalInput(ModelOutput): guidance_scale: float = None -# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ + # transpose to get (bsz, num_codebooks, seq_len) + input_ids = input_ids.transpose(1, 2) shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() if decoder_start_token_id is None: raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") - shifted_input_ids[:, 0] = decoder_start_token_id + shifted_input_ids[..., 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") @@ -909,6 +910,10 @@ MUSICGEN_INPUTS_DOCSTRING = r""" If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -1340,15 +1345,18 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - Returns: + Returns: """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (labels is not None) and (input_ids is None and inputs_embeds is None): + input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) + outputs = self.model( input_ids, attention_mask=attention_mask, @@ -1370,7 +1378,25 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): loss = None if labels is not None: - raise NotImplementedError("Training is not implemented for Musicgen.") + # since encoder hidden states have been concatenated to the decoder hidden states, + # we take the last timestamps corresponding to labels + logits = lm_logits[:, :, -labels.shape[1] :] + + loss_fct = CrossEntropyLoss() + loss = torch.zeros([], device=self.device) + + # per codebook cross-entropy + # -100 labels are ignored + labels = labels.masked_fill(labels == self.config.pad_token_id, -100) + + # per codebook cross-entropy + # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243 + for codebook in range(self.config.num_codebooks): + codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) + codebook_labels = labels[..., codebook].contiguous().view(-1) + loss += loss_fct(codebook_logits, codebook_labels) + + loss = loss / self.config.num_codebooks # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) @@ -2235,7 +2261,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id + labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id ) elif decoder_input_ids is None and decoder_inputs_embeds is None: @@ -2270,23 +2296,15 @@ class MusicgenForConditionalGeneration(PreTrainedModel): use_cache=use_cache, past_key_values=past_key_values, return_dict=return_dict, + labels=labels, **kwargs_decoder, ) - loss = None - if labels is not None: - logits = decoder_outputs.logits if return_dict else decoder_outputs[0] - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - if not return_dict: - if loss is not None: - return (loss,) + decoder_outputs + encoder_outputs - else: - return decoder_outputs + encoder_outputs + return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( - loss=loss, + loss=decoder_outputs.loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, @@ -2524,7 +2542,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( @@ -2533,6 +2551,22 @@ class MusicgenForConditionalGeneration(PreTrainedModel): " model.decoder.resize_token_embeddings(...))" ) + def freeze_audio_encoder(self): + """ + Freeze the audio encoder weights. + """ + for param in self.audio_encoder.parameters(): + param.requires_grad = False + self.audio_encoder._requires_grad = False + + def freeze_text_encoder(self): + """ + Freeze the text encoder weights. + """ + for param in self.text_encoder.parameters(): + param.requires_grad = False + self.text_encoder._requires_grad = False + def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 0840635f65..867983acb7 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -116,16 +116,18 @@ class MusicgenMelodyOutputWithPast(ModelOutput): encoder_hidden_states: Optional[torch.FloatTensor] = None -# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right +# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ + # transpose to get (bsz, num_codebooks, seq_len) + input_ids = input_ids.transpose(1, 2) shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() if decoder_start_token_id is None: raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") - shifted_input_ids[:, 0] = decoder_start_token_id + shifted_input_ids[..., 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") @@ -864,7 +866,7 @@ MUSICGEN_MELODY_INPUTS_DOCSTRING = r""" If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of `inputs_embeds`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` @@ -1269,7 +1271,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MusicgenMelodyOutputWithPast]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` @@ -1278,6 +1280,9 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if (labels is not None) and (input_ids is None and inputs_embeds is None): + input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) + outputs = self.model( input_ids, attention_mask=attention_mask, @@ -1298,7 +1303,25 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): loss = None if labels is not None: - raise NotImplementedError("Training is not implemented for MusicgenMelody.") + # since encoder hidden states have been concatenated to the decoder hidden states, + # we take the last timestamps corresponding to labels + logits = lm_logits[:, :, -labels.shape[1] :] + + loss_fct = CrossEntropyLoss() + loss = torch.zeros([], device=self.device) + + # per codebook cross-entropy + # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243 + # -100 labels are ignored + labels = labels.masked_fill(labels == self.config.pad_token_id, -100) + + # per codebook cross-entropy + for codebook in range(self.config.num_codebooks): + codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) + codebook_labels = labels[..., codebook].contiguous().view(-1) + loss += loss_fct(codebook_logits, codebook_labels) + + loss = loss / self.config.num_codebooks # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) @@ -2156,7 +2179,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id + labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id ) # Decode @@ -2170,23 +2193,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): use_cache=use_cache, past_key_values=past_key_values, return_dict=return_dict, + labels=labels, **kwargs_decoder, ) - loss = None - if labels is not None: - logits = decoder_outputs.logits if return_dict else decoder_outputs[0] - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - if not return_dict: - if loss is not None: - return (loss,) + decoder_outputs + (encoder_hidden_states,) - else: - return decoder_outputs + (encoder_hidden_states,) + return decoder_outputs + (encoder_hidden_states,) return MusicgenMelodyOutputWithPast( - loss=loss, + loss=decoder_outputs.loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, @@ -2397,7 +2412,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( @@ -2428,6 +2443,22 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): break return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + def freeze_audio_encoder(self): + """ + Freeze the audio encoder weights. + """ + for param in self.audio_encoder.parameters(): + param.requires_grad = False + self.audio_encoder._requires_grad = False + + def freeze_text_encoder(self): + """ + Freeze the text encoder weights. + """ + for param in self.text_encoder.parameters(): + param.requires_grad = False + self.text_encoder._requires_grad = False + @torch.no_grad() def generate( self, diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index d8baa6fd0c..5ac9c97479 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -110,8 +110,7 @@ class MusicgenDecoderTester: parent, batch_size=4, # need batch_size != num_hidden_layers seq_length=7, - is_training=False, - use_labels=False, + is_training=True, vocab_size=99, hidden_size=16, num_hidden_layers=2, @@ -129,7 +128,6 @@ class MusicgenDecoderTester: self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -149,7 +147,9 @@ class MusicgenDecoderTester: config = self.get_config() inputs_dict = prepare_musicgen_decoder_inputs_dict( - config, input_ids, encoder_hidden_states=encoder_hidden_states + config, + input_ids, + encoder_hidden_states=encoder_hidden_states, ) return config, inputs_dict @@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste def test_config(self): self.config_tester.run_common_tests() + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = MusicgenForCausalLM(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # Contrarily to the initial method, we don't unfreeze freezed parameters. + # Indeed, sinusoidal position embeddings have frozen weights that should stay frozen. + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {MusicgenForCausalLM.__name__} has no gradient!") + # override since we have to compute the input embeddings over codebooks def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict( head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, + labels=None, ): if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.reshape( @@ -923,6 +963,7 @@ def prepare_musicgen_inputs_dict( "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, + "labels": labels, } @@ -932,8 +973,7 @@ class MusicgenTester: parent, batch_size=4, # need batch_size != num_hidden_layers seq_length=7, - is_training=False, - use_labels=False, + is_training=True, vocab_size=99, hidden_size=16, num_hidden_layers=2, @@ -953,7 +993,6 @@ class MusicgenTester: self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def setUp(self): self.model_tester = MusicgenTester(self) + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = model_class(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # The audio encoder weights are not used during the forward pass (only during the generate pass) + # So we need to freeze it to be able to train. + model.freeze_audio_encoder() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!") + def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids): text_encoder_config = config.text_encoder decoder_config = config.decoder @@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.assertNotIn(config.pad_token_id, output_generate) + @unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model") + def test_save_load_fast_init_from_base(self): + pass + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.assertTrue(torch.allclose(res_eager, res_sdpa)) + def test_requires_grad_with_frozen_encoders(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.freeze_audio_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertFalse(all(audio_encoder_grads)) + self.assertTrue(all(text_encoder_grads)) + + model = model_class(config) + model.freeze_text_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertTrue(all(audio_encoder_grads)) + self.assertFalse(all(text_encoder_grads)) + def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): """Produces a series of 'bip bip' sounds at a given frequency.""" diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index bbabd41501..98d8cc0b9f 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -109,8 +109,7 @@ class MusicgenMelodyDecoderTester: parent, batch_size=3, # need batch_size != num_hidden_layers because of #29297 seq_length=7, - is_training=False, - use_labels=False, + is_training=True, vocab_size=99, hidden_size=16, num_hidden_layers=2, @@ -129,7 +128,6 @@ class MusicgenMelodyDecoderTester: self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -151,7 +149,9 @@ class MusicgenMelodyDecoderTester: config = self.get_config() inputs_dict = prepare_musicgen_melody_decoder_inputs_dict( - config, input_ids, encoder_hidden_states=encoder_hidden_states + config, + input_ids, + encoder_hidden_states=encoder_hidden_states, ) return config, inputs_dict @@ -191,6 +191,47 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes def test_config(self): self.config_tester.run_common_tests() + # special case for labels + # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest._prepare_for_class + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.check_training_gradient_checkpointing with Musicgen->MusicgenMelody + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = MusicgenMelodyForCausalLM(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # Contrarily to the initial method, we don't unfreeze freezed parameters. + # Indeed, sinusoidal position embeddings have frozen weights that should stay frozen. + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, MusicgenMelodyForCausalLM, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {MusicgenMelodyForCausalLM.__name__} has no gradient!") + # override since we have to compute the input embeddings over codebooks def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -896,6 +937,7 @@ def prepare_musicgen_melody_inputs_dict( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + labels=None, ): if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.reshape( @@ -917,6 +959,7 @@ def prepare_musicgen_melody_inputs_dict( "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "labels": labels, } @@ -926,8 +969,7 @@ class MusicgenMelodyTester: parent, batch_size=3, # need batch_size != num_hidden_layers because of #29297 seq_length=7, - is_training=False, - use_labels=False, + is_training=True, vocab_size=99, hidden_size=16, num_hidden_layers=2, @@ -949,7 +991,6 @@ class MusicgenMelodyTester: self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -1029,6 +1070,47 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def setUp(self): self.model_tester = MusicgenMelodyTester(self) + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = model_class(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # The audio encoder weights are not used during the forward pass (only during the generate pass) + # So we need to freeze it to be able to train. + model.freeze_audio_encoder() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!") + # Ignore copy def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids): decoder_config = config.decoder @@ -1500,6 +1582,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester self.assertNotIn(config.pad_token_id, output_generate) + @unittest.skip( + "MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model" + ) + def test_save_load_fast_init_from_base(self): + pass + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -2133,6 +2221,27 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester self.assertTrue(torch.allclose(res_eager, res_sdpa)) + def test_requires_grad_with_frozen_encoders(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.freeze_audio_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertFalse(all(audio_encoder_grads)) + self.assertTrue(all(text_encoder_grads)) + + model = model_class(config) + model.freeze_text_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertTrue(all(audio_encoder_grads)) + self.assertFalse(all(text_encoder_grads)) + # Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):