[Qwen2Audio] handle input ids expansion during processing (#35534)

* add audio_token attribute to proc

* expand input_ids

* and legacy and expanded input_ids

* test update

* split lines

* add possibility not to provide eos and bos audio tokens

* raise errors

* test incorrect number of audio tokens

* add example

* fmt

* typo
This commit is contained in:
eustlb
2025-01-07 16:47:27 +01:00
committed by GitHub
parent 628cd838a3
commit 7f7677307c
4 changed files with 159 additions and 48 deletions

View File

@@ -34,6 +34,37 @@ The abstract from the paper is the following:
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
### Inference
```python
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# We can also omit the audio_bos and audio_eos tokens
prompt = "<|AUDIO|>Generate the caption in English:"
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```
In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
### Voice Chat Inference

View File

@@ -1197,9 +1197,34 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.multi_modal_projector(selected_audio_feature)
# if we have consecutive audio tokens, then it means we expanded input_ids in processing
audio_tokens = input_ids == self.config.audio_token_index
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0
if legacy_processing:
logger.warning_once(
"Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
)
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
)
else:
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
audio_features = audio_features[audio_features_mask]
n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_features = audio_features.shape[0]
if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
outputs = self.language_model(
attention_mask=attention_mask,

View File

@@ -40,15 +40,32 @@ class Qwen2AudioProcessor(ProcessorMixin):
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
The token to use for audio tokens.
audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
The token to use for audio bos tokens.
audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
The token to use for audio eos tokens.
"""
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "AutoTokenizer"
def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
def __init__(
self,
feature_extractor=None,
tokenizer=None,
chat_template=None,
audio_token="<|AUDIO|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
):
if chat_template is None:
chat_template = self.default_chat_template
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
def __call__(
@@ -88,7 +105,18 @@ class Qwen2AudioProcessor(ProcessorMixin):
if text is None:
raise ValueError("You need to specify either a `text` input to process.")
inputs = self.tokenizer(text, padding=padding, **kwargs)
elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# ensure we have as much audios as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if type(audios) == np.ndarray else len(audios)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
)
if audios is not None:
audio_inputs = self.feature_extractor(
@@ -97,6 +125,46 @@ class Qwen2AudioProcessor(ProcessorMixin):
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename attention_mask to prevent conflicts later on
expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
for sample in text:
replace_str = []
while self.audio_token in sample:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
expanded_audio_token = self.audio_token * num_audio_tokens
audio_token_start_idx = sample.find(self.audio_token)
audio_token_end_idx = audio_token_start_idx + len(self.audio_token)
has_bos = (
sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
== self.audio_bos_token
)
has_eos = (
sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
== self.audio_eos_token
)
# Check if this audio token is surrounded by bos/eos tokens
if not has_bos and not has_eos:
expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token
replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)
while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
inputs = self.tokenizer(text, padding=padding, **kwargs)
if audios is not None:
inputs.update(audio_inputs)
return BatchFeature(data={**inputs})

View File

@@ -49,7 +49,7 @@ class Qwen2AudioModelTester:
parent,
ignore_index=-100,
audio_token_index=0,
seq_length=7,
seq_length=25,
feat_seq_length=60,
text_config={
"model_type": "qwen2",
@@ -91,7 +91,7 @@ class Qwen2AudioModelTester:
self.is_training = is_training
self.batch_size = 3
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
self.encoder_seq_length = seq_length
def get_config(self):
return Qwen2AudioConfig(
@@ -116,11 +116,13 @@ class Qwen2AudioModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_features_values, feature_attention_mask = config_and_inputs
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 audios let's make sure we pass in 3 audios tokens
input_ids[:, 1] = config.audio_token_index
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
inputs_dict = {
"input_features": input_features_values,
"feature_attention_mask": feature_attention_mask,
@@ -237,54 +239,39 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=32)
EXPECTED_INPUT_IDS = torch.tensor(
[
[
151644,
8948,
198,
2610,
525,
264,
10950,
17847,
13,
151645,
198,
151644,
872,
198,
14755,
220,
16,
25,
220,
151647,
151646,
151648,
198,
3838,
594,
429,
5112,
30,
151645,
198,
151644,
77091,
198,
]
]
)
# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 101,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
)
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT,
)
# test the error when incorrect number of audio tokens
# fmt: off
inputs["input_ids"] = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 200,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)
@slow
def test_small_model_integration_test_batch(self):
# Let' s make sure we test the preprocessing to replace what is used