[qwen2 audio] remove redundant code and update docs (#36282)
This commit is contained in:
@@ -29,7 +29,7 @@ The Qwen2-Audio is the new model series of large audio-language models from the
|
|||||||
* voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input
|
* voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input
|
||||||
* audio analysis: users could provide audio and text instructions for analysis during the interaction
|
* audio analysis: users could provide audio and text instructions for analysis during the interaction
|
||||||
|
|
||||||
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
|
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
|
||||||
|
|
||||||
The abstract from the paper is the following:
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ for message in conversation:
|
|||||||
for ele in message["content"]:
|
for ele in message["content"]:
|
||||||
if ele["type"] == "audio":
|
if ele["type"] == "audio":
|
||||||
audios.append(librosa.load(
|
audios.append(librosa.load(
|
||||||
BytesIO(urlopen(ele['audio_url']).read()),
|
BytesIO(urlopen(ele['audio_url']).read()),
|
||||||
sr=processor.feature_extractor.sampling_rate)[0]
|
sr=processor.feature_extractor.sampling_rate)[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
|
|||||||
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
|
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
|
||||||
|
|
||||||
conversation = [
|
conversation = [
|
||||||
{'role': 'system', 'content': 'You are a helpful assistant.'},
|
{'role': 'system', 'content': 'You are a helpful assistant.'},
|
||||||
{"role": "user", "content": [
|
{"role": "user", "content": [
|
||||||
{"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
|
{"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
|
||||||
{"type": "text", "text": "What's that sound?"},
|
{"type": "text", "text": "What's that sound?"},
|
||||||
@@ -148,7 +148,7 @@ for message in conversation:
|
|||||||
if ele["type"] == "audio":
|
if ele["type"] == "audio":
|
||||||
audios.append(
|
audios.append(
|
||||||
librosa.load(
|
librosa.load(
|
||||||
BytesIO(urlopen(ele['audio_url']).read()),
|
BytesIO(urlopen(ele['audio_url']).read()),
|
||||||
sr=processor.feature_extractor.sampling_rate)[0]
|
sr=processor.feature_extractor.sampling_rate)[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ for conversation in conversations:
|
|||||||
if ele["type"] == "audio":
|
if ele["type"] == "audio":
|
||||||
audios.append(
|
audios.append(
|
||||||
librosa.load(
|
librosa.load(
|
||||||
BytesIO(urlopen(ele['audio_url']).read()),
|
BytesIO(urlopen(ele['audio_url']).read()),
|
||||||
sr=processor.feature_extractor.sampling_rate)[0]
|
sr=processor.feature_extractor.sampling_rate)[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -221,7 +221,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
|||||||
|
|
||||||
[[autodoc]] Qwen2AudioConfig
|
[[autodoc]] Qwen2AudioConfig
|
||||||
|
|
||||||
## Qwen2AudioConfig
|
## Qwen2AudioEncoderConfig
|
||||||
|
|
||||||
[[autodoc]] Qwen2AudioEncoderConfig
|
[[autodoc]] Qwen2AudioEncoderConfig
|
||||||
|
|
||||||
@@ -229,6 +229,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
|||||||
|
|
||||||
[[autodoc]] Qwen2AudioProcessor
|
[[autodoc]] Qwen2AudioProcessor
|
||||||
|
|
||||||
|
## Qwen2AudioEncoder
|
||||||
|
|
||||||
|
[[autodoc]] Qwen2AudioEncoder
|
||||||
|
- forward
|
||||||
|
|
||||||
## Qwen2AudioForConditionalGeneration
|
## Qwen2AudioForConditionalGeneration
|
||||||
|
|
||||||
[[autodoc]] Qwen2AudioForConditionalGeneration
|
[[autodoc]] Qwen2AudioForConditionalGeneration
|
||||||
|
|||||||
@@ -16,14 +16,14 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, EncoderDecoderCache, StaticCache
|
from ...cache_utils import Cache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
@@ -35,6 +35,7 @@ from ...utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
from ..auto import AutoModel, AutoModelForCausalLM
|
from ..auto import AutoModel, AutoModelForCausalLM
|
||||||
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
||||||
|
|
||||||
@@ -58,12 +59,15 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
|
|||||||
Language modeling loss (for next-token prediction).
|
Language modeling loss (for next-token prediction).
|
||||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
|
||||||
|
The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
|
||||||
|
It is a [`~cache_utils.Cache`] instance.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
|
||||||
`past_key_values` input) to speed up sequential decoding.
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||||
|
all `input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -81,16 +85,16 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
|
|||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[Cache] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attention_mask: Optional[torch.FloatTensor] = None
|
attention_mask: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention with Whisper->Qwen2Audio
|
|
||||||
class Qwen2AudioAttention(nn.Module):
|
class Qwen2AudioAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
# Copied from transformers.models.whisper.modeling_whisper.WhisperAttention.__init__ with Whisper->Qwen2Audio
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
@@ -135,11 +139,14 @@ class Qwen2AudioAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
@deprecate_kwarg("key_value_states", version="4.52")
|
||||||
|
@deprecate_kwarg("past_key_value", version="4.52")
|
||||||
|
@deprecate_kwarg("cache_position", version="4.52")
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
@@ -147,38 +154,12 @@ class Qwen2AudioAttention(nn.Module):
|
|||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
||||||
# for the decoder
|
|
||||||
is_cross_attention = key_value_states is not None
|
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
|
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
if past_key_value is not None:
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
||||||
if is_cross_attention:
|
|
||||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
||||||
past_key_value.is_updated[self.layer_idx] = True
|
|
||||||
past_key_value = past_key_value.cross_attention_cache
|
|
||||||
else:
|
|
||||||
past_key_value = past_key_value.self_attention_cache
|
|
||||||
|
|
||||||
# use key_value_states if cross attention
|
|
||||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
||||||
if is_cross_attention and past_key_value and is_updated:
|
|
||||||
# reuse k,v, cross_attentions
|
|
||||||
key_states = past_key_value.key_cache[self.layer_idx]
|
|
||||||
value_states = past_key_value.value_cache[self.layer_idx]
|
|
||||||
else:
|
|
||||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
|
||||||
if past_key_value is not None:
|
|
||||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
||||||
cache_position = cache_position if not is_cross_attention else None
|
|
||||||
key_states, value_states = past_key_value.update(
|
|
||||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
|
||||||
@@ -212,10 +193,9 @@ class Qwen2AudioAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio
|
|
||||||
class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
||||||
"""
|
"""
|
||||||
Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
|
Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
|
||||||
@@ -223,6 +203,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
flash attention and deal with padding tokens in case the input contains any of them.
|
flash attention and deal with padding tokens in case the input contains any of them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@@ -231,57 +212,29 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
@deprecate_kwarg("key_value_states", version="4.52")
|
||||||
|
@deprecate_kwarg("past_key_value", version="4.52")
|
||||||
|
@deprecate_kwarg("cache_position", version="4.52")
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if isinstance(past_key_value, StaticCache):
|
|
||||||
raise ValueError(
|
|
||||||
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
|
||||||
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
|
|
||||||
)
|
|
||||||
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
|
raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
||||||
# for the decoder
|
|
||||||
is_cross_attention = key_value_states is not None
|
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
|
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
if past_key_value is not None:
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
||||||
if is_cross_attention:
|
|
||||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
||||||
past_key_value.is_updated[self.layer_idx] = True
|
|
||||||
past_key_value = past_key_value.cross_attention_cache
|
|
||||||
else:
|
|
||||||
past_key_value = past_key_value.self_attention_cache
|
|
||||||
|
|
||||||
# use key_value_states if cross attention
|
|
||||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
||||||
if is_cross_attention and past_key_value and is_updated:
|
|
||||||
# reuse k,v, cross_attentions
|
|
||||||
key_states = past_key_value.key_cache[self.layer_idx]
|
|
||||||
value_states = past_key_value.value_cache[self.layer_idx]
|
|
||||||
else:
|
|
||||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
|
||||||
if past_key_value is not None:
|
|
||||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
||||||
cache_position = cache_position if not is_cross_attention else None
|
|
||||||
key_states, value_states = past_key_value.update(
|
|
||||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
||||||
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||||
@@ -335,16 +288,18 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio
|
|
||||||
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
||||||
|
@deprecate_kwarg("key_value_states", version="4.52")
|
||||||
|
@deprecate_kwarg("past_key_value", version="4.52")
|
||||||
|
@deprecate_kwarg("cache_position", version="4.52")
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
@@ -359,46 +314,17 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
|||||||
)
|
)
|
||||||
return super().forward(
|
return super().forward(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
key_value_states=key_value_states,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
layer_head_mask=layer_head_mask,
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
|
||||||
# for the decoder
|
|
||||||
is_cross_attention = key_value_states is not None
|
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
if past_key_value is not None:
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
|
||||||
if is_cross_attention:
|
|
||||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
|
||||||
past_key_value.is_updated[self.layer_idx] = True
|
|
||||||
past_key_value = past_key_value.cross_attention_cache
|
|
||||||
else:
|
|
||||||
past_key_value = past_key_value.self_attention_cache
|
|
||||||
|
|
||||||
# use key_value_states if cross attention
|
|
||||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
|
||||||
if is_cross_attention and past_key_value and is_updated:
|
|
||||||
# reuse k,v, cross_attentions
|
|
||||||
key_states = past_key_value.key_cache[self.layer_idx]
|
|
||||||
value_states = past_key_value.value_cache[self.layer_idx]
|
|
||||||
else:
|
|
||||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
|
||||||
if past_key_value is not None:
|
|
||||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
|
||||||
cache_position = cache_position if not is_cross_attention else None
|
|
||||||
key_states, value_states = past_key_value.update(
|
|
||||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
|
||||||
)
|
|
||||||
|
|
||||||
causal_mask = attention_mask
|
causal_mask = attention_mask
|
||||||
if attention_mask is not None: # no matter the length, we just slice it
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
@@ -434,7 +360,7 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
|||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, None
|
||||||
|
|
||||||
|
|
||||||
QWEN2AUDIO_ATTENTION_CLASSES = {
|
QWEN2AUDIO_ATTENTION_CLASSES = {
|
||||||
@@ -815,16 +741,15 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
|
|||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
two sets of pre-computed hidden-states: key and values states in the self-attention blocks.
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
|
||||||
|
It is a [`~cache_utils.Cache`] instance.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||||
|
all `input_ids` of shape `(batch_size, sequence_length)`.shape `(batch_size, 1)` instead of all
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
@@ -851,7 +776,7 @@ QWEN2AUDIO_INPUTS_DOCSTRING = r"""
|
|||||||
class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
|
class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
|
||||||
def __init__(self, config: Qwen2AudioConfig):
|
def __init__(self, config: Qwen2AudioConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.audio_tower = AutoModel.from_config(config.audio_config)
|
self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance
|
||||||
|
|
||||||
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
|
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
|
||||||
self.vocab_size = config.text_config.vocab_size
|
self.vocab_size = config.text_config.vocab_size
|
||||||
@@ -1103,7 +1028,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
feature_attention_mask: Optional[torch.Tensor] = None,
|
feature_attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Cache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1258,78 +1183,5 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
input_features=None,
|
|
||||||
attention_mask=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# Overwritten -- custom processing (note: might not be needed, but there are no generation tests running atm)
|
|
||||||
|
|
||||||
if past_key_values is not None:
|
|
||||||
if isinstance(past_key_values, Cache):
|
|
||||||
cache_length = past_key_values.get_seq_length()
|
|
||||||
past_length = past_key_values.seen_tokens
|
|
||||||
else:
|
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
|
||||||
|
|
||||||
# Here, we get the attention_mask, which was previously stored in the state after _merge_input_ids_with_audio_features.
|
|
||||||
if input_features is not None and kwargs.get("attention_mask") is not None:
|
|
||||||
attention_mask = kwargs["attention_mask"]
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
|
||||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
|
||||||
# input)
|
|
||||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
|
||||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
|
||||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
|
||||||
# input_ids based on the past_length.
|
|
||||||
elif past_length < input_ids.shape[1]:
|
|
||||||
input_ids = input_ids[:, past_length:]
|
|
||||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
|
||||||
elif self.config.audio_token_index in input_ids:
|
|
||||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
|
||||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
|
||||||
# older attention values, as their corresponding values are not part of the input.
|
|
||||||
if cache_length < past_length and attention_mask is not None:
|
|
||||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
if attention_mask is not None and position_ids is None:
|
|
||||||
# create position_ids on the fly for batch generation
|
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
||||||
if past_key_values:
|
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
||||||
else:
|
|
||||||
model_inputs = {"input_ids": input_ids}
|
|
||||||
|
|
||||||
feature_attention_mask = kwargs.get("feature_attention_mask", None)
|
|
||||||
model_inputs.update(
|
|
||||||
{
|
|
||||||
"position_ids": position_ids,
|
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"input_features": input_features,
|
|
||||||
"feature_attention_mask": feature_attention_mask,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def _reorder_cache(self, *args, **kwargs):
|
|
||||||
return self.language_model._reorder_cache(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
|
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
|
||||||
|
|||||||
Reference in New Issue
Block a user