[Qwen2Audio] handle input ids expansion during processing (#35534)

* add audio_token attribute to proc

* expand input_ids

* add handling of legacy and expanded input_ids

* test update

* split lines

* add possibility not to provide eos and bos audio tokens

* raise errors

* test incorrect number of audio tokens

* add example

* fmt

* typo
This commit is contained in:
eustlb
2025-01-07 16:47:27 +01:00
committed by GitHub
parent 628cd838a3
commit 7f7677307c
4 changed files with 159 additions and 48 deletions

View File

@@ -34,6 +34,37 @@ The abstract from the paper is the following:
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen) `Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
### Inference
```python
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
# Load the base (non-instruct) checkpoint together with its matching processor.
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)

# Download the clip (closing the HTTP response once it is read) and resample it
# to the sampling rate the feature extractor is configured with.
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
with urlopen(url) as remote_file:
    audio_bytes = remote_file.read()
audio, sr = librosa.load(BytesIO(audio_bytes), sr=processor.feature_extractor.sampling_rate)

prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
# The processor expands each `<|AUDIO|>` placeholder into one token per audio feature,
# so the prompt length grows with the clip duration. Bound only the newly generated
# tokens (`max_new_tokens`) rather than the total length (`max_length`), which longer
# clips could exhaust before any text is generated.
generate_ids = model.generate(**inputs, max_new_tokens=256)
# Keep only the continuation: drop the prompt tokens echoed back by `generate`.
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# We can also omit the audio_bos and audio_eos tokens
prompt = "<|AUDIO|>Generate the caption in English:"
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_new_tokens=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```
In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose. In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
### Voice Chat Inference ### Voice Chat Inference

View File

@@ -1197,9 +1197,34 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
selected_audio_feature = audio_outputs.last_hidden_state selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.multi_modal_projector(selected_audio_feature) audio_features = self.multi_modal_projector(selected_audio_feature)
# if we have consecutive audio tokens, then it means we expanded input_ids in processing
audio_tokens = input_ids == self.config.audio_token_index
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0
if legacy_processing:
logger.warning_once(
"Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
)
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features( inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
) )
else:
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
audio_features = audio_features[audio_features_mask]
n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_features = audio_features.shape[0]
if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
outputs = self.language_model( outputs = self.language_model(
attention_mask=attention_mask, attention_mask=attention_mask,

View File

@@ -40,15 +40,32 @@ class Qwen2AudioProcessor(ProcessorMixin):
chat_template (`Optional[str]`, *optional*): chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used. is used.
audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
The token to use for audio tokens.
audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
The token to use for audio bos tokens.
audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
The token to use for audio eos tokens.
""" """
attributes = ["feature_extractor", "tokenizer"] attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor" feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None): def __init__(
self,
feature_extractor=None,
tokenizer=None,
chat_template=None,
audio_token="<|AUDIO|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
):
if chat_template is None: if chat_template is None:
chat_template = self.default_chat_template chat_template = self.default_chat_template
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template) super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
def __call__( def __call__(
@@ -88,7 +105,18 @@ class Qwen2AudioProcessor(ProcessorMixin):
if text is None: if text is None:
raise ValueError("You need to specify either a `text` input to process.") raise ValueError("You need to specify either a `text` input to process.")
inputs = self.tokenizer(text, padding=padding, **kwargs) elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# ensure we have as many audios as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if type(audios) == np.ndarray else len(audios)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
)
if audios is not None: if audios is not None:
audio_inputs = self.feature_extractor( audio_inputs = self.feature_extractor(
@@ -97,6 +125,46 @@ class Qwen2AudioProcessor(ProcessorMixin):
audio_inputs["feature_attention_mask"] = audio_inputs.pop( audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask" "attention_mask"
) # rename attention_mask to prevent conflicts later on ) # rename attention_mask to prevent conflicts later on
expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
for sample in text:
replace_str = []
while self.audio_token in sample:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
expanded_audio_token = self.audio_token * num_audio_tokens
audio_token_start_idx = sample.find(self.audio_token)
audio_token_end_idx = audio_token_start_idx + len(self.audio_token)
has_bos = (
sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
== self.audio_bos_token
)
has_eos = (
sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
== self.audio_eos_token
)
# Check if this audio token is surrounded by bos/eos tokens
if not has_bos and not has_eos:
expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token
replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)
while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
inputs = self.tokenizer(text, padding=padding, **kwargs)
if audios is not None:
inputs.update(audio_inputs) inputs.update(audio_inputs)
return BatchFeature(data={**inputs}) return BatchFeature(data={**inputs})

View File

@@ -49,7 +49,7 @@ class Qwen2AudioModelTester:
parent, parent,
ignore_index=-100, ignore_index=-100,
audio_token_index=0, audio_token_index=0,
seq_length=7, seq_length=25,
feat_seq_length=60, feat_seq_length=60,
text_config={ text_config={
"model_type": "qwen2", "model_type": "qwen2",
@@ -91,7 +91,7 @@ class Qwen2AudioModelTester:
self.is_training = is_training self.is_training = is_training
self.batch_size = 3 self.batch_size = 3
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1 self.encoder_seq_length = seq_length
def get_config(self): def get_config(self):
return Qwen2AudioConfig( return Qwen2AudioConfig(
@@ -116,11 +116,13 @@ class Qwen2AudioModelTester:
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, input_features_values, feature_attention_mask = config_and_inputs config, input_features_values, feature_attention_mask = config_and_inputs
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0 attention_mask[:, :1] = 0
# we are giving 3 audios let's make sure we pass in 3 audios tokens # we are giving 3 audios let's make sure we pass in 3 audios tokens
input_ids[:, 1] = config.audio_token_index input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
inputs_dict = { inputs_dict = {
"input_features": input_features_values, "input_features": input_features_values,
"feature_attention_mask": feature_attention_mask, "feature_attention_mask": feature_attention_mask,
@@ -237,54 +239,39 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=32) output = model.generate(**inputs, max_new_tokens=32)
EXPECTED_INPUT_IDS = torch.tensor( # fmt: off
[ EXPECTED_INPUT_IDS = torch.tensor([[
[ 151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
151644, *[151646] * 101,
8948, 151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
198, ]])
2610, # fmt: on
525,
264,
10950,
17847,
13,
151645,
198,
151644,
872,
198,
14755,
220,
16,
25,
220,
151647,
151646,
151648,
198,
3838,
594,
429,
5112,
30,
151645,
198,
151644,
77091,
198,
]
]
)
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>" EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
)
self.assertEqual( self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False), self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT, EXPECTED_DECODED_TEXT,
) )
# test the error when incorrect number of audio tokens
# fmt: off
inputs["input_ids"] = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 200,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)
@slow @slow
def test_small_model_integration_test_batch(self): def test_small_model_integration_test_batch(self):
# Let' s make sure we test the preprocessing to replace what is used # Let' s make sure we test the preprocessing to replace what is used