[Qwen2Audio] handle input ids expansion during processing (#35534)
* add audio_token attribute to proc
* expand input_ids
* and legacy and expanded input_ids
* test update
* split lines
* add possibility not to provide eos and bos audio tokens
* raise errors
* test incorrect number of audio tokens
* add example
* fmt
* typo
This commit is contained in:
@@ -49,7 +49,7 @@ class Qwen2AudioModelTester:
|
||||
parent,
|
||||
ignore_index=-100,
|
||||
audio_token_index=0,
|
||||
seq_length=7,
|
||||
seq_length=25,
|
||||
feat_seq_length=60,
|
||||
text_config={
|
||||
"model_type": "qwen2",
|
||||
@@ -91,7 +91,7 @@ class Qwen2AudioModelTester:
|
||||
self.is_training = is_training
|
||||
|
||||
self.batch_size = 3
|
||||
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
|
||||
self.encoder_seq_length = seq_length
|
||||
|
||||
def get_config(self):
|
||||
return Qwen2AudioConfig(
|
||||
@@ -116,11 +116,13 @@ class Qwen2AudioModelTester:
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_features_values, feature_attention_mask = config_and_inputs
|
||||
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
|
||||
num_audio_tokens = (input_length - 2) // 2 + 1
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
attention_mask[:, :1] = 0
|
||||
# we are giving 3 audios, so let's make sure we pass in 3 audio tokens
|
||||
input_ids[:, 1] = config.audio_token_index
|
||||
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
|
||||
inputs_dict = {
|
||||
"input_features": input_features_values,
|
||||
"feature_attention_mask": feature_attention_mask,
|
||||
@@ -237,54 +239,39 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=32)
|
||||
|
||||
EXPECTED_INPUT_IDS = torch.tensor(
|
||||
[
|
||||
[
|
||||
151644,
|
||||
8948,
|
||||
198,
|
||||
2610,
|
||||
525,
|
||||
264,
|
||||
10950,
|
||||
17847,
|
||||
13,
|
||||
151645,
|
||||
198,
|
||||
151644,
|
||||
872,
|
||||
198,
|
||||
14755,
|
||||
220,
|
||||
16,
|
||||
25,
|
||||
220,
|
||||
151647,
|
||||
151646,
|
||||
151648,
|
||||
198,
|
||||
3838,
|
||||
594,
|
||||
429,
|
||||
5112,
|
||||
30,
|
||||
151645,
|
||||
198,
|
||||
151644,
|
||||
77091,
|
||||
198,
|
||||
]
|
||||
]
|
||||
)
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_IDS = torch.tensor([[
|
||||
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
|
||||
*[151646] * 101,
|
||||
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
|
||||
]])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
||||
|
||||
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
|
||||
EXPECTED_DECODED_TEXT = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
|
||||
+ "<|AUDIO|>" * 101
|
||||
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=False),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
# test the error when incorrect number of audio tokens
|
||||
# fmt: off
|
||||
inputs["input_ids"] = torch.tensor([[
|
||||
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
|
||||
*[151646] * 200,
|
||||
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
|
||||
]])
|
||||
# fmt: on
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
|
||||
):
|
||||
model.generate(**inputs, max_new_tokens=32)
|
||||
|
||||
@slow
|
||||
def test_small_model_integration_test_batch(self):
|
||||
# Let's make sure we test the preprocessing to replace what is used
|
||||
|
||||
Reference in New Issue
Block a user