MusicGen Update (#27084)
* [MusicGen] Add stereo model * safe serialization * Update src/transformers/models/musicgen/modeling_musicgen.py * split over 2 lines * fix slow tests on cuda
This commit is contained in:
@@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
def prepare_musicgen_inputs_dict(
|
||||
config,
|
||||
@@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||
@@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
output_values.shape == (2, 1, 36480)
|
||||
) # input values take shape 32000 and we generate from there
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class MusicgenStereoIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def model(self):
|
||||
return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device)
|
||||
|
||||
@cached_property
|
||||
def processor(self):
|
||||
return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")
|
||||
|
||||
@slow
|
||||
def test_generate_unconditional_greedy(self):
|
||||
model = self.model
|
||||
|
||||
# only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
|
||||
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
|
||||
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
|
||||
|
||||
output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_VALUES_LEFT = torch.tensor(
|
||||
[
|
||||
0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013,
|
||||
-0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099,
|
||||
]
|
||||
)
|
||||
EXPECTED_VALUES_RIGHT = torch.tensor(
|
||||
[
|
||||
0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019,
|
||||
0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (1, 2, 5760))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_generate_text_audio_prompt(self):
|
||||
model = self.model
|
||||
processor = self.processor
|
||||
|
||||
# create stereo inputs
|
||||
audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)]
|
||||
text = ["80s music", "Club techno"]
|
||||
|
||||
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
|
||||
inputs = place_dict_on_device(inputs, device=torch_device)
|
||||
|
||||
output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_VALUES_LEFT = torch.tensor(
|
||||
[
|
||||
0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728,
|
||||
-0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430,
|
||||
]
|
||||
)
|
||||
EXPECTED_VALUES_RIGHT = torch.tensor(
|
||||
[
|
||||
0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103,
|
||||
-0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (2, 2, 37760))
|
||||
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user