🚨 Add training compatibility for Musicgen-like models (#29802)
* first modeling code * make repository * still WIP * update model * add tests * add latest change * clean docstrings and copied from * update docstrings md and readme * correct chroma function * correct copied from and remove unreleated test * add doc to toctree * correct imports * add convert script to notdoctested * Add suggestion from Sanchit Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct get_uncoditional_inputs docstrings * modify README according to SANCHIT feedback * add chroma to audio utils * clean librosa and torchaudio hard dependencies * fix FE * refactor audio decoder -> audio encoder for consistency with previous musicgen * refactor conditional -> encoder * modify sampling rate logics * modify license at the beginning * refactor all_self_attns->all_attentions * remove ignore copy from causallm generate * add copied from for from_sub_models * fix make copies * add warning if audio is truncated * add copied from where relevant * remove artefact * fix convert script * fix torchaudio and FE * modify chroma method according to feedback-> better naming * refactor input_values->input_features * refactor input_values->input_features and fix import fe * add input_features to docstrigs * correct inputs_embeds logics * remove dtype conversion * refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation * change warning for chroma length * Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change way to save wav, using soundfile * correct docs and change to soundfile * fix import * fix init proj layers * add draft training * fix cross entropy * clean loss computation * fix labels * remove line breaks from md * fix issue with docstrings * add FE suggestions * improve is in logics and remove useless imports * remove custom from_pretrained * simplify docstring code * add suggestions for modeling tests * make style * update converting script with sanity check * remove encoder attention mask from conditional generation * replace musicgen melody checkpoints with official orga * rename ylacombe->facebook in checkpoints * fix copies * remove unecessary warning * add shape in code docstrings * add files to slow doc tests * fix md bug and add md to not_tested * make fix-copies * fix hidden states test and batching * update training code * add training tests for melody * add training for o.g musicgen * fix copied from * remove final todos * make style * fix style * add suggestions from review * add ref to the original loss computation code * rename method + fix labels in tests * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -161,6 +161,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("mpt", "MptModel"),
|
("mpt", "MptModel"),
|
||||||
("mra", "MraModel"),
|
("mra", "MraModel"),
|
||||||
("mt5", "MT5Model"),
|
("mt5", "MT5Model"),
|
||||||
|
("musicgen", "MusicgenModel"),
|
||||||
|
("musicgen_melody", "MusicgenMelodyModel"),
|
||||||
("mvp", "MvpModel"),
|
("mvp", "MvpModel"),
|
||||||
("nat", "NatModel"),
|
("nat", "NatModel"),
|
||||||
("nezha", "NezhaModel"),
|
("nezha", "NezhaModel"),
|
||||||
|
|||||||
@@ -104,16 +104,17 @@ class MusicgenUnconditionalInput(ModelOutput):
|
|||||||
guidance_scale: float = None
|
guidance_scale: float = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
|
|
||||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
"""
|
"""
|
||||||
Shift input ids one token to the right.
|
Shift input ids one token to the right.
|
||||||
"""
|
"""
|
||||||
|
# transpose to get (bsz, num_codebooks, seq_len)
|
||||||
|
input_ids = input_ids.transpose(1, 2)
|
||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
if decoder_start_token_id is None:
|
if decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||||
|
|
||||||
if pad_token_id is None:
|
if pad_token_id is None:
|
||||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||||
@@ -909,6 +910,10 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||||||
of `inputs_embeds`.
|
of `inputs_embeds`.
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||||
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
use_cache (`bool`, *optional*):
|
use_cache (`bool`, *optional*):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
@@ -1340,15 +1345,18 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (labels is not None) and (input_ids is None and inputs_embeds is None):
|
||||||
|
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -1370,7 +1378,25 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
raise NotImplementedError("Training is not implemented for Musicgen.")
|
# since encoder hidden states have been concatenated to the decoder hidden states,
|
||||||
|
# we take the last timestamps corresponding to labels
|
||||||
|
logits = lm_logits[:, :, -labels.shape[1] :]
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = torch.zeros([], device=self.device)
|
||||||
|
|
||||||
|
# per codebook cross-entropy
|
||||||
|
# -100 labels are ignored
|
||||||
|
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
|
||||||
|
|
||||||
|
# per codebook cross-entropy
|
||||||
|
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
|
||||||
|
for codebook in range(self.config.num_codebooks):
|
||||||
|
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
|
||||||
|
codebook_labels = labels[..., codebook].contiguous().view(-1)
|
||||||
|
loss += loss_fct(codebook_logits, codebook_labels)
|
||||||
|
|
||||||
|
loss = loss / self.config.num_codebooks
|
||||||
|
|
||||||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
||||||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
||||||
@@ -2235,7 +2261,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
|
|
||||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif decoder_input_ids is None and decoder_inputs_embeds is None:
|
elif decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
@@ -2270,23 +2296,15 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
labels=labels,
|
||||||
**kwargs_decoder,
|
**kwargs_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if loss is not None:
|
return decoder_outputs + encoder_outputs
|
||||||
return (loss,) + decoder_outputs + encoder_outputs
|
|
||||||
else:
|
|
||||||
return decoder_outputs + encoder_outputs
|
|
||||||
|
|
||||||
return Seq2SeqLMOutput(
|
return Seq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=decoder_outputs.loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
@@ -2524,7 +2542,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
|
||||||
|
|
||||||
def resize_token_embeddings(self, *args, **kwargs):
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -2533,6 +2551,22 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
" model.decoder.resize_token_embeddings(...))"
|
" model.decoder.resize_token_embeddings(...))"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def freeze_audio_encoder(self):
|
||||||
|
"""
|
||||||
|
Freeze the audio encoder weights.
|
||||||
|
"""
|
||||||
|
for param in self.audio_encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.audio_encoder._requires_grad = False
|
||||||
|
|
||||||
|
def freeze_text_encoder(self):
|
||||||
|
"""
|
||||||
|
Freeze the text encoder weights.
|
||||||
|
"""
|
||||||
|
for param in self.text_encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.text_encoder._requires_grad = False
|
||||||
|
|
||||||
def _maybe_initialize_input_ids_for_generation(
|
def _maybe_initialize_input_ids_for_generation(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@@ -116,16 +116,18 @@ class MusicgenMelodyOutputWithPast(ModelOutput):
|
|||||||
encoder_hidden_states: Optional[torch.FloatTensor] = None
|
encoder_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
|
# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right
|
||||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
"""
|
"""
|
||||||
Shift input ids one token to the right.
|
Shift input ids one token to the right.
|
||||||
"""
|
"""
|
||||||
|
# transpose to get (bsz, num_codebooks, seq_len)
|
||||||
|
input_ids = input_ids.transpose(1, 2)
|
||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
if decoder_start_token_id is None:
|
if decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||||
|
|
||||||
if pad_token_id is None:
|
if pad_token_id is None:
|
||||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||||
@@ -864,7 +866,7 @@ MUSICGEN_MELODY_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||||||
of `inputs_embeds`.
|
of `inputs_embeds`.
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
@@ -1269,7 +1271,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, MusicgenMelodyOutputWithPast]:
|
) -> Union[Tuple, MusicgenMelodyOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
@@ -1278,6 +1280,9 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (labels is not None) and (input_ids is None and inputs_embeds is None):
|
||||||
|
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -1298,7 +1303,25 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
raise NotImplementedError("Training is not implemented for MusicgenMelody.")
|
# since encoder hidden states have been concatenated to the decoder hidden states,
|
||||||
|
# we take the last timestamps corresponding to labels
|
||||||
|
logits = lm_logits[:, :, -labels.shape[1] :]
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = torch.zeros([], device=self.device)
|
||||||
|
|
||||||
|
# per codebook cross-entropy
|
||||||
|
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
|
||||||
|
# -100 labels are ignored
|
||||||
|
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
|
||||||
|
|
||||||
|
# per codebook cross-entropy
|
||||||
|
for codebook in range(self.config.num_codebooks):
|
||||||
|
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
|
||||||
|
codebook_labels = labels[..., codebook].contiguous().view(-1)
|
||||||
|
loss += loss_fct(codebook_logits, codebook_labels)
|
||||||
|
|
||||||
|
loss = loss / self.config.num_codebooks
|
||||||
|
|
||||||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
||||||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
||||||
@@ -2156,7 +2179,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
|
|
||||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
@@ -2170,23 +2193,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
labels=labels,
|
||||||
**kwargs_decoder,
|
**kwargs_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if loss is not None:
|
return decoder_outputs + (encoder_hidden_states,)
|
||||||
return (loss,) + decoder_outputs + (encoder_hidden_states,)
|
|
||||||
else:
|
|
||||||
return decoder_outputs + (encoder_hidden_states,)
|
|
||||||
|
|
||||||
return MusicgenMelodyOutputWithPast(
|
return MusicgenMelodyOutputWithPast(
|
||||||
loss=loss,
|
loss=decoder_outputs.loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
hidden_states=decoder_outputs.hidden_states,
|
hidden_states=decoder_outputs.hidden_states,
|
||||||
@@ -2397,7 +2412,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
|
||||||
|
|
||||||
def resize_token_embeddings(self, *args, **kwargs):
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -2428,6 +2443,22 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
break
|
break
|
||||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||||
|
|
||||||
|
def freeze_audio_encoder(self):
|
||||||
|
"""
|
||||||
|
Freeze the audio encoder weights.
|
||||||
|
"""
|
||||||
|
for param in self.audio_encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.audio_encoder._requires_grad = False
|
||||||
|
|
||||||
|
def freeze_text_encoder(self):
|
||||||
|
"""
|
||||||
|
Freeze the text encoder weights.
|
||||||
|
"""
|
||||||
|
for param in self.text_encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.text_encoder._requires_grad = False
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -110,8 +110,7 @@ class MusicgenDecoderTester:
|
|||||||
parent,
|
parent,
|
||||||
batch_size=4, # need batch_size != num_hidden_layers
|
batch_size=4, # need batch_size != num_hidden_layers
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=True,
|
||||||
use_labels=False,
|
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
@@ -129,7 +128,6 @@ class MusicgenDecoderTester:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
@@ -149,7 +147,9 @@ class MusicgenDecoderTester:
|
|||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
inputs_dict = prepare_musicgen_decoder_inputs_dict(
|
inputs_dict = prepare_musicgen_decoder_inputs_dict(
|
||||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
config,
|
||||||
|
input_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
)
|
)
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
# special case for labels
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = MusicgenForCausalLM(config)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||||
|
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
self.assertTrue(v.grad is not None, f"{k} in {MusicgenForCausalLM.__name__} has no gradient!")
|
||||||
|
|
||||||
# override since we have to compute the input embeddings over codebooks
|
# override since we have to compute the input embeddings over codebooks
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict(
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
|
labels=None,
|
||||||
):
|
):
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.reshape(
|
decoder_attention_mask = decoder_input_ids.reshape(
|
||||||
@@ -923,6 +963,7 @@ def prepare_musicgen_inputs_dict(
|
|||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
"cross_attn_head_mask": cross_attn_head_mask,
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
|
"labels": labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -932,8 +973,7 @@ class MusicgenTester:
|
|||||||
parent,
|
parent,
|
||||||
batch_size=4, # need batch_size != num_hidden_layers
|
batch_size=4, # need batch_size != num_hidden_layers
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=True,
|
||||||
use_labels=False,
|
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
@@ -953,7 +993,6 @@ class MusicgenTester:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
@@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = MusicgenTester(self)
|
self.model_tester = MusicgenTester(self)
|
||||||
|
|
||||||
|
# special case for labels
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||||
|
# So we need to freeze it to be able to train.
|
||||||
|
model.freeze_audio_encoder()
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||||
|
|
||||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||||
text_encoder_config = config.text_encoder
|
text_encoder_config = config.text_encoder
|
||||||
decoder_config = config.decoder
|
decoder_config = config.decoder
|
||||||
@@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
self.assertNotIn(config.pad_token_id, output_generate)
|
||||||
|
|
||||||
|
@unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model")
|
||||||
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||||
|
|
||||||
|
def test_requires_grad_with_frozen_encoders(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.freeze_audio_encoder()
|
||||||
|
|
||||||
|
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||||
|
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||||
|
|
||||||
|
self.assertFalse(all(audio_encoder_grads))
|
||||||
|
self.assertTrue(all(text_encoder_grads))
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.freeze_text_encoder()
|
||||||
|
|
||||||
|
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||||
|
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||||
|
|
||||||
|
self.assertTrue(all(audio_encoder_grads))
|
||||||
|
self.assertFalse(all(text_encoder_grads))
|
||||||
|
|
||||||
|
|
||||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||||
|
|||||||
@@ -109,8 +109,7 @@ class MusicgenMelodyDecoderTester:
|
|||||||
parent,
|
parent,
|
||||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=True,
|
||||||
use_labels=False,
|
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
@@ -129,7 +128,6 @@ class MusicgenMelodyDecoderTester:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
@@ -151,7 +149,9 @@ class MusicgenMelodyDecoderTester:
|
|||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
inputs_dict = prepare_musicgen_melody_decoder_inputs_dict(
|
inputs_dict = prepare_musicgen_melody_decoder_inputs_dict(
|
||||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
config,
|
||||||
|
input_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
)
|
)
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -191,6 +191,47 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
# special case for labels
|
||||||
|
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest._prepare_for_class
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.check_training_gradient_checkpointing with Musicgen->MusicgenMelody
|
||||||
|
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = MusicgenMelodyForCausalLM(config)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||||
|
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, MusicgenMelodyForCausalLM, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
self.assertTrue(v.grad is not None, f"{k} in {MusicgenMelodyForCausalLM.__name__} has no gradient!")
|
||||||
|
|
||||||
# override since we have to compute the input embeddings over codebooks
|
# override since we have to compute the input embeddings over codebooks
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -896,6 +937,7 @@ def prepare_musicgen_melody_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
labels=None,
|
||||||
):
|
):
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.reshape(
|
decoder_attention_mask = decoder_input_ids.reshape(
|
||||||
@@ -917,6 +959,7 @@ def prepare_musicgen_melody_inputs_dict(
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"labels": labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -926,8 +969,7 @@ class MusicgenMelodyTester:
|
|||||||
parent,
|
parent,
|
||||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=True,
|
||||||
use_labels=False,
|
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
@@ -949,7 +991,6 @@ class MusicgenMelodyTester:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
@@ -1029,6 +1070,47 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = MusicgenMelodyTester(self)
|
self.model_tester = MusicgenMelodyTester(self)
|
||||||
|
|
||||||
|
# special case for labels
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_cache = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||||
|
# So we need to freeze it to be able to train.
|
||||||
|
model.freeze_audio_encoder()
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
if v.requires_grad:
|
||||||
|
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||||
|
|
||||||
# Ignore copy
|
# Ignore copy
|
||||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||||
decoder_config = config.decoder
|
decoder_config = config.decoder
|
||||||
@@ -1500,6 +1582,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
self.assertNotIn(config.pad_token_id, output_generate)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
|
||||||
|
)
|
||||||
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@@ -2133,6 +2221,27 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||||
|
|
||||||
|
def test_requires_grad_with_frozen_encoders(self):
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.freeze_audio_encoder()
|
||||||
|
|
||||||
|
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||||
|
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||||
|
|
||||||
|
self.assertFalse(all(audio_encoder_grads))
|
||||||
|
self.assertTrue(all(text_encoder_grads))
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.freeze_text_encoder()
|
||||||
|
|
||||||
|
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||||
|
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||||
|
|
||||||
|
self.assertTrue(all(audio_encoder_grads))
|
||||||
|
self.assertFalse(all(text_encoder_grads))
|
||||||
|
|
||||||
|
|
||||||
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
|
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
|
||||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||||
|
|||||||
Reference in New Issue
Block a user