[Qwen2Audio] handle input ids expansion during processing (#35534)
* add audio_token attribute to proc * expand input_ids * and legacy and expanded input_ids * test update * split lines * add possibility not to provide eos and bos audio tokens * raise errors * test incorrect number of audio tokens * add example * fmt * typo
This commit is contained in:
@@ -34,6 +34,37 @@ The abstract from the paper is the following:
|
|||||||
|
|
||||||
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
||||||
|
|
||||||
|
### Inference
|
||||||
|
|
||||||
|
```python
|
||||||
|
from io import BytesIO
|
||||||
|
from urllib.request import urlopen
|
||||||
|
import librosa
|
||||||
|
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
|
||||||
|
|
||||||
|
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
|
||||||
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)
|
||||||
|
|
||||||
|
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
|
||||||
|
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
|
||||||
|
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
|
||||||
|
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
generate_ids = model.generate(**inputs, max_length=256)
|
||||||
|
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||||
|
|
||||||
|
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
|
||||||
|
# We can also omit the audio_bos and audio_eos tokens
|
||||||
|
prompt = "<|AUDIO|>Generate the caption in English:"
|
||||||
|
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
generate_ids = model.generate(**inputs, max_length=256)
|
||||||
|
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||||
|
|
||||||
|
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
```
|
||||||
|
|
||||||
In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
|
In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
|
||||||
|
|
||||||
### Voice Chat Inference
|
### Voice Chat Inference
|
||||||
|
|||||||
@@ -1197,9 +1197,34 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
selected_audio_feature = audio_outputs.last_hidden_state
|
selected_audio_feature = audio_outputs.last_hidden_state
|
||||||
audio_features = self.multi_modal_projector(selected_audio_feature)
|
audio_features = self.multi_modal_projector(selected_audio_feature)
|
||||||
|
|
||||||
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
|
# if we have consecutive audio tokens, then it means we expanded input_ids in processing
|
||||||
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
|
audio_tokens = input_ids == self.config.audio_token_index
|
||||||
)
|
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0
|
||||||
|
|
||||||
|
if legacy_processing:
|
||||||
|
logger.warning_once(
|
||||||
|
"Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
|
||||||
|
)
|
||||||
|
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
|
||||||
|
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_audios, max_audio_tokens, embed_dim = audio_features.shape
|
||||||
|
audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
|
||||||
|
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
|
||||||
|
audio_features = audio_features[audio_features_mask]
|
||||||
|
|
||||||
|
n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
|
||||||
|
n_audio_features = audio_features.shape[0]
|
||||||
|
|
||||||
|
if n_audio_tokens != n_audio_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
|
||||||
|
)
|
||||||
|
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
|
||||||
|
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
|
||||||
|
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
|||||||
@@ -40,15 +40,32 @@ class Qwen2AudioProcessor(ProcessorMixin):
|
|||||||
chat_template (`Optional[str]`, *optional*):
|
chat_template (`Optional[str]`, *optional*):
|
||||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template
|
The Jinja template to use for formatting the conversation. If not provided, the default chat template
|
||||||
is used.
|
is used.
|
||||||
|
audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
|
||||||
|
The token to use for audio tokens.
|
||||||
|
audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
|
||||||
|
The token to use for audio bos tokens.
|
||||||
|
audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
|
||||||
|
The token to use for audio eos tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = ["feature_extractor", "tokenizer"]
|
attributes = ["feature_extractor", "tokenizer"]
|
||||||
feature_extractor_class = "WhisperFeatureExtractor"
|
feature_extractor_class = "WhisperFeatureExtractor"
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_extractor=None,
|
||||||
|
tokenizer=None,
|
||||||
|
chat_template=None,
|
||||||
|
audio_token="<|AUDIO|>",
|
||||||
|
audio_bos_token="<|audio_bos|>",
|
||||||
|
audio_eos_token="<|audio_eos|>",
|
||||||
|
):
|
||||||
if chat_template is None:
|
if chat_template is None:
|
||||||
chat_template = self.default_chat_template
|
chat_template = self.default_chat_template
|
||||||
|
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
|
||||||
|
self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
|
||||||
|
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
|
||||||
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -88,7 +105,18 @@ class Qwen2AudioProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
if text is None:
|
if text is None:
|
||||||
raise ValueError("You need to specify either a `text` input to process.")
|
raise ValueError("You need to specify either a `text` input to process.")
|
||||||
inputs = self.tokenizer(text, padding=padding, **kwargs)
|
elif isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
|
# ensure we have as much audios as audio tokens
|
||||||
|
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
|
||||||
|
num_audios = 1 if type(audios) == np.ndarray else len(audios)
|
||||||
|
if num_audio_tokens != num_audios:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
if audios is not None:
|
if audios is not None:
|
||||||
audio_inputs = self.feature_extractor(
|
audio_inputs = self.feature_extractor(
|
||||||
@@ -97,6 +125,46 @@ class Qwen2AudioProcessor(ProcessorMixin):
|
|||||||
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
||||||
"attention_mask"
|
"attention_mask"
|
||||||
) # rename attention_mask to prevent conflicts later on
|
) # rename attention_mask to prevent conflicts later on
|
||||||
|
|
||||||
|
expanded_text = []
|
||||||
|
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
|
||||||
|
|
||||||
|
for sample in text:
|
||||||
|
replace_str = []
|
||||||
|
while self.audio_token in sample:
|
||||||
|
audio_length = audio_lengths.pop(0)
|
||||||
|
input_length = (audio_length - 1) // 2 + 1
|
||||||
|
num_audio_tokens = (input_length - 2) // 2 + 1
|
||||||
|
|
||||||
|
expanded_audio_token = self.audio_token * num_audio_tokens
|
||||||
|
|
||||||
|
audio_token_start_idx = sample.find(self.audio_token)
|
||||||
|
audio_token_end_idx = audio_token_start_idx + len(self.audio_token)
|
||||||
|
|
||||||
|
has_bos = (
|
||||||
|
sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
|
||||||
|
== self.audio_bos_token
|
||||||
|
)
|
||||||
|
has_eos = (
|
||||||
|
sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
|
||||||
|
== self.audio_eos_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this audio token is surrounded by bos/eos tokens
|
||||||
|
if not has_bos and not has_eos:
|
||||||
|
expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token
|
||||||
|
|
||||||
|
replace_str.append(expanded_audio_token)
|
||||||
|
sample = sample.replace(self.audio_token, "<placeholder>", 1)
|
||||||
|
|
||||||
|
while "<placeholder>" in sample:
|
||||||
|
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
|
||||||
|
expanded_text.append(sample)
|
||||||
|
text = expanded_text
|
||||||
|
|
||||||
|
inputs = self.tokenizer(text, padding=padding, **kwargs)
|
||||||
|
|
||||||
|
if audios is not None:
|
||||||
inputs.update(audio_inputs)
|
inputs.update(audio_inputs)
|
||||||
|
|
||||||
return BatchFeature(data={**inputs})
|
return BatchFeature(data={**inputs})
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class Qwen2AudioModelTester:
|
|||||||
parent,
|
parent,
|
||||||
ignore_index=-100,
|
ignore_index=-100,
|
||||||
audio_token_index=0,
|
audio_token_index=0,
|
||||||
seq_length=7,
|
seq_length=25,
|
||||||
feat_seq_length=60,
|
feat_seq_length=60,
|
||||||
text_config={
|
text_config={
|
||||||
"model_type": "qwen2",
|
"model_type": "qwen2",
|
||||||
@@ -91,7 +91,7 @@ class Qwen2AudioModelTester:
|
|||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
|
||||||
self.batch_size = 3
|
self.batch_size = 3
|
||||||
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
|
self.encoder_seq_length = seq_length
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return Qwen2AudioConfig(
|
return Qwen2AudioConfig(
|
||||||
@@ -116,11 +116,13 @@ class Qwen2AudioModelTester:
|
|||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, input_features_values, feature_attention_mask = config_and_inputs
|
config, input_features_values, feature_attention_mask = config_and_inputs
|
||||||
|
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
|
||||||
|
num_audio_tokens = (input_length - 2) // 2 + 1
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||||
attention_mask[:, :1] = 0
|
attention_mask[:, :1] = 0
|
||||||
# we are giving 3 audios let's make sure we pass in 3 audios tokens
|
# we are giving 3 audios let's make sure we pass in 3 audios tokens
|
||||||
input_ids[:, 1] = config.audio_token_index
|
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"input_features": input_features_values,
|
"input_features": input_features_values,
|
||||||
"feature_attention_mask": feature_attention_mask,
|
"feature_attention_mask": feature_attention_mask,
|
||||||
@@ -237,54 +239,39 @@ class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=32)
|
output = model.generate(**inputs, max_new_tokens=32)
|
||||||
|
|
||||||
EXPECTED_INPUT_IDS = torch.tensor(
|
# fmt: off
|
||||||
[
|
EXPECTED_INPUT_IDS = torch.tensor([[
|
||||||
[
|
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
|
||||||
151644,
|
*[151646] * 101,
|
||||||
8948,
|
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
|
||||||
198,
|
]])
|
||||||
2610,
|
# fmt: on
|
||||||
525,
|
|
||||||
264,
|
|
||||||
10950,
|
|
||||||
17847,
|
|
||||||
13,
|
|
||||||
151645,
|
|
||||||
198,
|
|
||||||
151644,
|
|
||||||
872,
|
|
||||||
198,
|
|
||||||
14755,
|
|
||||||
220,
|
|
||||||
16,
|
|
||||||
25,
|
|
||||||
220,
|
|
||||||
151647,
|
|
||||||
151646,
|
|
||||||
151648,
|
|
||||||
198,
|
|
||||||
3838,
|
|
||||||
594,
|
|
||||||
429,
|
|
||||||
5112,
|
|
||||||
30,
|
|
||||||
151645,
|
|
||||||
198,
|
|
||||||
151644,
|
|
||||||
77091,
|
|
||||||
198,
|
|
||||||
]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
|
EXPECTED_DECODED_TEXT = (
|
||||||
|
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
|
||||||
|
+ "<|AUDIO|>" * 101
|
||||||
|
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.processor.decode(output[0], skip_special_tokens=False),
|
self.processor.decode(output[0], skip_special_tokens=False),
|
||||||
EXPECTED_DECODED_TEXT,
|
EXPECTED_DECODED_TEXT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# test the error when incorrect number of audio tokens
|
||||||
|
# fmt: off
|
||||||
|
inputs["input_ids"] = torch.tensor([[
|
||||||
|
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
|
||||||
|
*[151646] * 200,
|
||||||
|
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
|
||||||
|
]])
|
||||||
|
# fmt: on
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
|
||||||
|
):
|
||||||
|
model.generate(**inputs, max_new_tokens=32)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_model_integration_test_batch(self):
|
def test_small_model_integration_test_batch(self):
|
||||||
# Let' s make sure we test the preprocessing to replace what is used
|
# Let' s make sure we test the preprocessing to replace what is used
|
||||||
|
|||||||
Reference in New Issue
Block a user