[MusicGen] Fix integration tests (#25169)
* move to device * update with cuda values * fix fp16 * more rigorous
This commit is contained in:
@@ -773,10 +773,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device)
|
||||
|
||||
for codebook in range(num_codebooks):
|
||||
inputs_embeds += self.embed_tokens[codebook](input[:, codebook])
|
||||
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
||||
@@ -267,8 +267,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
@@ -293,8 +293,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
@@ -324,8 +324,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
num_return_sequences=3,
|
||||
logits_processor=logits_processor,
|
||||
@@ -356,8 +356,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
num_return_sequences=1,
|
||||
logits_processor=logits_processor,
|
||||
@@ -964,8 +964,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
@@ -989,8 +989,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
@@ -1019,8 +1019,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
num_return_sequences=1,
|
||||
@@ -1050,8 +1050,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
output_sample, output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
num_return_sequences=3,
|
||||
@@ -1089,8 +1089,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
model = model_class(config).eval().to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
model.half()
|
||||
model.generate(**input_dict, max_new_tokens=10)
|
||||
model.generate(**input_dict, do_sample=True, max_new_tokens=10)
|
||||
# greedy
|
||||
model.generate(input_dict["input_ids"], attention_mask=input_dict["attention_mask"], max_new_tokens=10)
|
||||
# sampling
|
||||
model.generate(
|
||||
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
@@ -1230,8 +1234,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: off
|
||||
EXPECTED_VALUES = torch.tensor(
|
||||
[
|
||||
0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760,
|
||||
0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730,
|
||||
-0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185,
|
||||
0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
@@ -1312,8 +1316,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: off
|
||||
EXPECTED_VALUES = torch.tensor(
|
||||
[
|
||||
-0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211,
|
||||
-0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098,
|
||||
-0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229,
|
||||
0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
Reference in New Issue
Block a user