🚨 Add training compatibility for Musicgen-like models (#29802)
* first modeling code * make repository * still WIP * update model * add tests * add latest change * clean docstrings and copied from * update docstrings md and readme * correct chroma function * correct copied from and remove unreleated test * add doc to toctree * correct imports * add convert script to notdoctested * Add suggestion from Sanchit Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct get_uncoditional_inputs docstrings * modify README according to SANCHIT feedback * add chroma to audio utils * clean librosa and torchaudio hard dependencies * fix FE * refactor audio decoder -> audio encoder for consistency with previous musicgen * refactor conditional -> encoder * modify sampling rate logics * modify license at the beginning * refactor all_self_attns->all_attentions * remove ignore copy from causallm generate * add copied from for from_sub_models * fix make copies * add warning if audio is truncated * add copied from where relevant * remove artefact * fix convert script * fix torchaudio and FE * modify chroma method according to feedback-> better naming * refactor input_values->input_features * refactor input_values->input_features and fix import fe * add input_features to docstrigs * correct inputs_embeds logics * remove dtype conversion * refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation * change warning for chroma length * Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change way to save wav, using soundfile * correct docs and change to soundfile * fix import * fix init proj layers * add draft training * fix cross entropy * clean loss computation * fix labels * remove line breaks from md * fix issue with docstrings * add FE suggestions * improve is in logics and remove useless imports * remove custom from_pretrained * simplify docstring code * add suggestions for modeling tests * make style * update converting script with sanity check * remove encoder attention mask from conditional generation * replace musicgen melody checkpoints with official orga * rename ylacombe->facebook in checkpoints * fix copies * remove unecessary warning * add shape in code docstrings * add files to slow doc tests * fix md bug and add md to not_tested * make fix-copies * fix hidden states test and batching * update training code * add training tests for melody * add training for o.g musicgen * fix copied from * remove final todos * make style * fix style * add suggestions from review * add ref to the original loss computation code * rename method + fix labels in tests * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -110,8 +110,7 @@ class MusicgenDecoderTester:
|
||||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
@@ -129,7 +128,6 @@ class MusicgenDecoderTester:
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -149,7 +147,9 @@ class MusicgenDecoderTester:
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_musicgen_decoder_inputs_dict(
|
||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
||||
config,
|
||||
input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = MusicgenForCausalLM(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {MusicgenForCausalLM.__name__} has no gradient!")
|
||||
|
||||
# override since we have to compute the input embeddings over codebooks
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict(
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.reshape(
|
||||
@@ -923,6 +963,7 @@ def prepare_musicgen_inputs_dict(
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
@@ -932,8 +973,7 @@ class MusicgenTester:
|
||||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
@@ -953,7 +993,6 @@ class MusicgenTester:
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def setUp(self):
|
||||
self.model_tester = MusicgenTester(self)
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||
# So we need to freeze it to be able to train.
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||
|
||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||
text_encoder_config = config.text_encoder
|
||||
decoder_config = config.decoder
|
||||
@@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertFalse(all(audio_encoder_grads))
|
||||
self.assertTrue(all(text_encoder_grads))
|
||||
|
||||
model = model_class(config)
|
||||
model.freeze_text_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertTrue(all(audio_encoder_grads))
|
||||
self.assertFalse(all(text_encoder_grads))
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||
|
||||
@@ -109,8 +109,7 @@ class MusicgenMelodyDecoderTester:
|
||||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
@@ -129,7 +128,6 @@ class MusicgenMelodyDecoderTester:
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -151,7 +149,9 @@ class MusicgenMelodyDecoderTester:
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_musicgen_melody_decoder_inputs_dict(
|
||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
||||
config,
|
||||
input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -191,6 +191,47 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# special case for labels
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest._prepare_for_class
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.check_training_gradient_checkpointing with Musicgen->MusicgenMelody
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = MusicgenMelodyForCausalLM(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, MusicgenMelodyForCausalLM, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {MusicgenMelodyForCausalLM.__name__} has no gradient!")
|
||||
|
||||
# override since we have to compute the input embeddings over codebooks
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -896,6 +937,7 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.reshape(
|
||||
@@ -917,6 +959,7 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
@@ -926,8 +969,7 @@ class MusicgenMelodyTester:
|
||||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
@@ -949,7 +991,6 @@ class MusicgenMelodyTester:
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -1029,6 +1070,47 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def setUp(self):
|
||||
self.model_tester = MusicgenMelodyTester(self)
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||
# So we need to freeze it to be able to train.
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||
|
||||
# Ignore copy
|
||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||
decoder_config = config.decoder
|
||||
@@ -1500,6 +1582,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@unittest.skip(
|
||||
"MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
|
||||
)
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@@ -2133,6 +2221,27 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertFalse(all(audio_encoder_grads))
|
||||
self.assertTrue(all(text_encoder_grads))
|
||||
|
||||
model = model_class(config)
|
||||
model.freeze_text_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertTrue(all(audio_encoder_grads))
|
||||
self.assertFalse(all(text_encoder_grads))
|
||||
|
||||
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
|
||||
Reference in New Issue
Block a user