MusicGen Update (#27084)

* [MusicGen] Add stereo model

* safe serialization

* Update src/transformers/models/musicgen/modeling_musicgen.py

* split over 2 lines

* fix slow tests on cuda
This commit is contained in:
Sanchit Gandhi
2023-11-08 13:26:02 +00:00
committed by GitHub
parent 5ef650b0ae
commit f16ff0f07e
5 changed files with 244 additions and 28 deletions

View File

@@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def prepare_musicgen_inputs_dict(
config,
@@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""
@@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase):
output_values.shape == (2, 1, 36480)
) # input values take shape 32000 and we generate from there
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
@require_torch
class MusicgenStereoIntegrationTests(unittest.TestCase):
@cached_property
def model(self):
return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device)
@cached_property
def processor(self):
return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")
@slow
def test_generate_unconditional_greedy(self):
model = self.model
# only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12)
# fmt: off
EXPECTED_VALUES_LEFT = torch.tensor(
[
0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013,
-0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099,
]
)
EXPECTED_VALUES_RIGHT = torch.tensor(
[
0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019,
0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096,
]
)
# fmt: on
# (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (1, 2, 5760))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
@slow
def test_generate_text_audio_prompt(self):
model = self.model
processor = self.processor
# create stereo inputs
audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)]
text = ["80s music", "Club techno"]
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
inputs = place_dict_on_device(inputs, device=torch_device)
output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12)
# fmt: off
EXPECTED_VALUES_LEFT = torch.tensor(
[
0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728,
-0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430,
]
)
EXPECTED_VALUES_RIGHT = torch.tensor(
[
0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103,
-0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614,
]
)
# fmt: on
# (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (2, 2, 37760))
# input values take shape 32000 and we generate from there - we check the last (generated) values
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))