🚨 Add training compatibility for Musicgen-like models (#29802)

* first modeling code

* make repository

* still WIP

* update model

* add tests

* add latest change

* clean docstrings and copied from

* update docstrings md and readme

* correct chroma function

* correct copied from and remove unrelated test

* add doc to toctree

* correct imports

* add convert script to notdoctested

* Add suggestion from Sanchit

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* correct get_unconditional_inputs docstrings

* modify README according to SANCHIT feedback

* add chroma to audio utils

* clean librosa and torchaudio hard dependencies

* fix FE

* refactor audio decoder -> audio encoder for consistency with previous musicgen

* refactor conditional -> encoder

* modify sampling rate logics

* modify license at the beginning

* refactor all_self_attns->all_attentions

* remove ignore copy from causallm generate

* add copied from for from_sub_models

* fix make copies

* add warning if audio is truncated

* add copied from where relevant

* remove artefact

* fix convert script

* fix torchaudio and FE

* modify chroma method according to feedback-> better naming

* refactor input_values->input_features

* refactor input_values->input_features and fix import fe

* add input_features to docstrings

* correct inputs_embeds logics

* remove dtype conversion

* refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation

* change warning for chroma length

* Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change way to save wav, using soundfile

* correct docs and change to soundfile

* fix import

* fix init proj layers

* add draft training

* fix cross entropy

* clean loss computation

* fix labels

* remove line breaks from md

* fix issue with docstrings

* add FE suggestions

* improve is in logics and remove useless imports

* remove custom from_pretrained

* simplify docstring code

* add suggestions for modeling tests

* make style

* update converting script with sanity check

* remove encoder attention mask from conditional generation

* replace musicgen melody checkpoints with official orga

* rename ylacombe->facebook in checkpoints

* fix copies

* remove unnecessary warning

* add shape in code docstrings

* add files to slow doc tests

* fix md bug and add md to not_tested

* make fix-copies

* fix hidden states test and batching

* update training code

* add training tests for melody

* add training for o.g musicgen

* fix copied from

* remove final todos

* make style

* fix style

* add suggestions from review

* add ref to the original loss computation code

* rename method + fix labels in tests

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2024-04-25 12:51:19 +02:00
committed by GitHub
parent ce5ae5a434
commit 90cb55bf77
5 changed files with 333 additions and 52 deletions

View File

@@ -110,8 +110,7 @@ class MusicgenDecoderTester:
parent,
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -129,7 +128,6 @@ class MusicgenDecoderTester:
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -149,7 +147,9 @@ class MusicgenDecoderTester:
config = self.get_config()
inputs_dict = prepare_musicgen_decoder_inputs_dict(
config, input_ids, encoder_hidden_states=encoder_hidden_states
config,
input_ids,
encoder_hidden_states=encoder_hidden_states,
)
return config, inputs_dict
@@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# Delegates to `self.config_tester.run_common_tests()`; no model is built here.
def test_config(self):
self.config_tester.run_common_tests()
# special case for labels
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
    """Prepare model inputs, adding per-codebook labels when requested.

    The parent implementation is called first. When `return_labels` is true,
    the "labels" entry is set to an all-zero `torch.long` tensor of shape
    `(batch_size, seq_length, num_codebooks)`, with the sizes read from the
    model tester, placed on `torch_device`.
    """
    prepared = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
    if not return_labels:
        return prepared

    tester = self.model_tester
    label_shape = (tester.batch_size, tester.seq_length, tester.num_codebooks)
    prepared["labels"] = torch.zeros(label_shape, dtype=torch.long, device=torch_device)
    return prepared
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
    """Run one SGD step on `MusicgenForCausalLM` with gradient checkpointing enabled.

    Returns immediately when the model tester is not set up for training.
    Otherwise, after a forward/backward pass with labels, asserts that every
    parameter with `requires_grad` set has received a gradient.
    """
    tester = self.model_tester
    if not tester.is_training:
        return

    config, inputs_dict = tester.prepare_config_and_inputs_for_common()
    config.use_cache = False
    config.return_dict = True

    model = MusicgenForCausalLM(config)
    model.to(torch_device)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    model.train()

    # Unlike the initial method, frozen parameters are not unfrozen here:
    # the sinusoidal position embeddings have frozen weights that should stay frozen.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True)
    loss = model(**inputs).loss
    loss.backward()
    optimizer.step()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        self.assertTrue(param.grad is not None, f"{name} in {MusicgenForCausalLM.__name__} has no gradient!")
# override since we have to compute the input embeddings over codebooks
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict(
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
labels=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.reshape(
@@ -923,6 +963,7 @@ def prepare_musicgen_inputs_dict(
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"labels": labels,
}
@@ -932,8 +973,7 @@ class MusicgenTester:
parent,
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -953,7 +993,6 @@ class MusicgenTester:
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# Stores a `MusicgenTester` built around this test case as `self.model_tester`.
def setUp(self):
self.model_tester = MusicgenTester(self)
# special case for labels
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
    """Prepare model inputs, adding per-codebook labels when requested.

    The parent implementation is called first. When `return_labels` is true,
    the "labels" entry is set to an all-zero `torch.long` tensor of shape
    `(batch_size, seq_length, num_codebooks)`, with the sizes read from the
    model tester, placed on `torch_device`.
    """
    prepared = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
    if not return_labels:
        return prepared

    tester = self.model_tester
    label_shape = (tester.batch_size, tester.seq_length, tester.num_codebooks)
    prepared["labels"] = torch.zeros(label_shape, dtype=torch.long, device=torch_device)
    return prepared
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
    """Run one SGD step per model class with gradient checkpointing enabled.

    Returns immediately when the model tester is not set up for training.
    For each class in `self.all_model_classes`, a model is built, its audio
    encoder is frozen, and after a forward/backward pass with labels every
    parameter with `requires_grad` set must have received a gradient.
    """
    tester = self.model_tester
    if not tester.is_training:
        return

    for model_class in self.all_model_classes:
        config, inputs_dict = tester.prepare_config_and_inputs_for_common()
        config.use_cache = False
        config.return_dict = True

        model = model_class(config)
        model.to(torch_device)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
        model.train()

        # The audio encoder weights are only used by the generate pass, not by the
        # forward pass, so the encoder has to be frozen for training to work.
        model.freeze_audio_encoder()

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
        loss = model(**inputs).loss
        loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.assertTrue(param.grad is not None, f"{name} in {model_class.__name__} has no gradient!")
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
text_encoder_config = config.text_encoder
decoder_config = config.decoder
@@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertNotIn(config.pad_token_id, output_generate)
@unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model")
def test_save_load_fast_init_from_base(self):
    """Intentionally empty: the skip decorator disables this test for the reason it states."""
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def test_requires_grad_with_frozen_encoders(self):
    """Check that freezing one encoder leaves the other encoder trainable.

    For every class in `self.all_model_classes`, a fresh model is built before
    each freeze call so the two checks cannot influence each other. After
    `freeze_audio_encoder()` no audio-encoder parameter may require gradients
    while all text-encoder parameters must; after `freeze_text_encoder()` the
    roles are swapped.
    """
    config = self.model_tester.get_config()
    for model_class in self.all_model_classes:
        model = model_class(config)
        model.freeze_audio_encoder()

        audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
        text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]

        # `any`, not `all`: `assertFalse(all(...))` would already pass with a single
        # frozen parameter and so would miss a partially frozen encoder.
        self.assertFalse(any(audio_encoder_grads))
        self.assertTrue(all(text_encoder_grads))

        model = model_class(config)
        model.freeze_text_encoder()

        audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
        text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]

        self.assertTrue(all(audio_encoder_grads))
        self.assertFalse(any(text_encoder_grads))
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""