[Qwen2Audio] handle input ids expansion during processing (#35534)

* add audio_token attribute to proc

* expand input_ids

* and legacy and expanded input_ids

* test update

* split lines

* add possibility not to provide eos and bos audio tokens

* raise errors

* test incorrect number of audio tokens

* add example

* fmt

* typo
This commit is contained in:
eustlb
2025-01-07 16:47:27 +01:00
committed by GitHub
parent 628cd838a3
commit 7f7677307c
4 changed files with 159 additions and 48 deletions

View File

@@ -49,7 +49,7 @@ class Qwen2AudioModelTester:
parent,
ignore_index=-100,
audio_token_index=0,
seq_length=7,
seq_length=25,
feat_seq_length=60,
text_config={
"model_type": "qwen2",
@@ -91,7 +91,7 @@ class Qwen2AudioModelTester:
self.is_training = is_training
self.batch_size = 3
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
self.encoder_seq_length = seq_length
def get_config(self):
return Qwen2AudioConfig(
@@ -116,11 +116,13 @@ class Qwen2AudioModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_features_values, feature_attention_mask = config_and_inputs
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 audios let's make sure we pass in 3 audios tokens
input_ids[:, 1] = config.audio_token_index
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
inputs_dict = {
"input_features": input_features_values,
"feature_attention_mask": feature_attention_mask,
@@ -237,54 +239,39 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=32)
EXPECTED_INPUT_IDS = torch.tensor(
[
[
151644,
8948,
198,
2610,
525,
264,
10950,
17847,
13,
151645,
198,
151644,
872,
198,
14755,
220,
16,
25,
220,
151647,
151646,
151648,
198,
3838,
594,
429,
5112,
30,
151645,
198,
151644,
77091,
198,
]
]
)
# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 101,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
)
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT,
)
# test the error when incorrect number of audio tokens
# fmt: off
inputs["input_ids"] = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 200,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)
@slow
def test_small_model_integration_test_batch(self):
# Let' s make sure we test the preprocessing to replace what is used