Add Flash Attention 2 support to Musicgen and Musicgen Melody (#29939)
* add FA2 to o.g Musicgen * make style * add FA2 support to Musicgen Melody * add generation FA2 tests to o.g Musicgen * make style and fix copies * add Musicgen to FA2 docs + deprecate list * add sdpa supports to Musicgen's * make style and fix copies * refactor attention implementation arguments * add Copied from to sdpa tests * add copied form in sdpa tests melody * add copied for FA2 generation tests * add FA2 inference copied from * make style
This commit is contained in:
@@ -55,6 +55,8 @@ FlashAttention-2 is currently supported for the following architectures:
|
|||||||
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||||
|
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||||
|
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||||
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
|
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
|
||||||
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
|
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
|
||||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||||
@@ -190,6 +192,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
||||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||||
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
||||||
|
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||||
|
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
|
|||||||
@@ -1470,6 +1470,12 @@ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
|
|||||||
|
|
||||||
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
|
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
|
||||||
|
|
||||||
|
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
|
||||||
|
{"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json"}
|
||||||
|
)
|
||||||
|
|
||||||
|
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-melody"])
|
||||||
|
|
||||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
|
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
|
||||||
[
|
[
|
||||||
"RUCAIBox/mvp",
|
"RUCAIBox/mvp",
|
||||||
|
|||||||
@@ -239,3 +239,20 @@ class MusicgenConfig(PretrainedConfig):
|
|||||||
# This is a property because you might want to change the codec model on the fly
|
# This is a property because you might want to change the codec model on the fly
|
||||||
def sampling_rate(self):
|
def sampling_rate(self):
|
||||||
return self.audio_encoder.sampling_rate
|
return self.audio_encoder.sampling_rate
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _attn_implementation(self):
|
||||||
|
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
|
||||||
|
if hasattr(self, "_attn_implementation_internal"):
|
||||||
|
if self._attn_implementation_internal is None:
|
||||||
|
# `config.attn_implementation` should never be None, for backward compatibility.
|
||||||
|
return "eager"
|
||||||
|
else:
|
||||||
|
return self._attn_implementation_internal
|
||||||
|
else:
|
||||||
|
return "eager"
|
||||||
|
|
||||||
|
@_attn_implementation.setter
|
||||||
|
def _attn_implementation(self, value):
|
||||||
|
self._attn_implementation_internal = value
|
||||||
|
self.decoder._attn_implementation = value
|
||||||
|
|||||||
@@ -22,13 +22,19 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...generation.configuration_utils import GenerationConfig
|
from ...generation.configuration_utils import GenerationConfig
|
||||||
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
||||||
from ...generation.stopping_criteria import StoppingCriteriaList
|
from ...generation.stopping_criteria import StoppingCriteriaList
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_attention_mask,
|
||||||
|
_prepare_4d_attention_mask_for_sdpa,
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -40,6 +46,8 @@ from ...modeling_utils import PreTrainedModel
|
|||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_flash_attn_2_available,
|
||||||
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -48,6 +56,10 @@ from ..auto.modeling_auto import AutoModel
|
|||||||
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
|
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ...generation.streamers import BaseStreamer
|
from ...generation.streamers import BaseStreamer
|
||||||
|
|
||||||
@@ -60,6 +72,19 @@ _CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
|
|||||||
from ..deprecated._archive_maps import MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
from ..deprecated._archive_maps import MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||||
|
def _get_unpad_data(attention_mask):
|
||||||
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
return (
|
||||||
|
indices,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen_in_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MusicgenUnconditionalInput(ModelOutput):
|
class MusicgenUnconditionalInput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -302,29 +327,361 @@ class MusicgenAttention(nn.Module):
|
|||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
|
||||||
|
class MusicgenFlashAttention2(MusicgenAttention):
|
||||||
|
"""
|
||||||
|
Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stay
|
||||||
|
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||||
|
flash attention and deal with padding tokens in case the input contains any of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# MusicgenFlashAttention2 attention does not support output_attentions
|
||||||
|
if output_attentions:
|
||||||
|
raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
|
||||||
|
|
||||||
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
|
# for the decoder
|
||||||
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
||||||
|
# get key, value proj
|
||||||
|
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||||
|
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||||
|
# the provided `key_value_states` to support prefix tuning
|
||||||
|
if (
|
||||||
|
is_cross_attention
|
||||||
|
and past_key_value is not None
|
||||||
|
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||||
|
):
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value[0].transpose(1, 2)
|
||||||
|
value_states = past_key_value[1].transpose(1, 2)
|
||||||
|
elif is_cross_attention:
|
||||||
|
# cross_attentions
|
||||||
|
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
||||||
|
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
||||||
|
elif past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
||||||
|
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
||||||
|
else:
|
||||||
|
# self_attention
|
||||||
|
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states get silently cast to float32. Hence, we need to
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slow down training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
attn_output = self._flash_attention_forward(
|
||||||
|
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||||
|
def _flash_attention_forward(
|
||||||
|
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||||
|
first unpads the input, then computes the attention scores and pads the final attention scores.
|
||||||
|
Args:
|
||||||
|
query_states (`torch.Tensor`):
|
||||||
|
Input query states to be passed to Flash Attention API
|
||||||
|
key_states (`torch.Tensor`):
|
||||||
|
Input key states to be passed to Flash Attention API
|
||||||
|
value_states (`torch.Tensor`):
|
||||||
|
Input value states to be passed to Flash Attention API
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||||
|
position of padding tokens and 1 for the position of non-padding tokens.
|
||||||
|
dropout (`float`):
|
||||||
|
Attention dropout
|
||||||
|
softmax_scale (`float`, *optional*):
|
||||||
|
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
|
||||||
|
"""
|
||||||
|
if not self._flash_attn_uses_top_left_mask:
|
||||||
|
causal = self.is_causal
|
||||||
|
else:
|
||||||
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||||
|
causal = self.is_causal and query_length != 1
|
||||||
|
|
||||||
|
# Contains at least one padding token in the sequence
|
||||||
|
if attention_mask is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||||
|
query_states, key_states, value_states, attention_mask, query_length
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||||
|
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
key_layer = index_first_axis(
|
||||||
|
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
value_layer = index_first_axis(
|
||||||
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(
|
||||||
|
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||||
|
) # There is a memcpy here, that is very bad.
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
attention_mask = attention_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
|
||||||
|
class MusicgenSdpaAttention(MusicgenAttention):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
if output_attentions or layer_head_mask is not None:
|
||||||
|
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||||
|
logger.warning_once(
|
||||||
|
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||||
|
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states,
|
||||||
|
key_value_states=key_value_states,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
|
# for the decoder
|
||||||
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
# get key, value proj
|
||||||
|
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||||
|
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||||
|
# the provided `key_value_states` to support prefix tuning
|
||||||
|
if (
|
||||||
|
is_cross_attention
|
||||||
|
and past_key_value is not None
|
||||||
|
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||||
|
):
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value[0]
|
||||||
|
value_states = past_key_value[1]
|
||||||
|
elif is_cross_attention:
|
||||||
|
# cross_attentions
|
||||||
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||||
|
elif past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
else:
|
||||||
|
# self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz)
|
||||||
|
|
||||||
|
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||||
|
# but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=self.dropout if self.training else 0.0,
|
||||||
|
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||||
|
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||||
|
# partitioned across GPUs when using tensor-parallelism.
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
MUSICGEN_ATTENTION_CLASSES = {
|
||||||
|
"eager": MusicgenAttention,
|
||||||
|
"sdpa": MusicgenSdpaAttention,
|
||||||
|
"flash_attention_2": MusicgenFlashAttention2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MusicgenDecoderLayer(nn.Module):
|
class MusicgenDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: MusicgenDecoderConfig):
|
def __init__(self, config: MusicgenDecoderConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = MusicgenAttention(
|
self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
is_causal=True,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = ACT2FN[config.activation_function]
|
||||||
self.activation_dropout = config.activation_dropout
|
self.activation_dropout = config.activation_dropout
|
||||||
|
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
self.encoder_attn = MusicgenAttention(
|
self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
|
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
|
||||||
@@ -432,6 +789,8 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_factor
|
std = self.config.initializer_factor
|
||||||
@@ -667,6 +1026,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
self.attn_implementation = config._attn_implementation
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
@@ -721,16 +1081,40 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
|
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
|
||||||
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
if self.attn_implementation == "flash_attention_2":
|
||||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
)
|
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
|
||||||
|
# output_attentions=True & head_mask can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
if self.attn_implementation == "flash_attention_2":
|
||||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
|
||||||
)
|
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D (non-causal) attention mask in all cases.
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||||
|
encoder_attention_mask,
|
||||||
|
inputs_embeds.dtype,
|
||||||
|
tgt_len=input_shape[-1],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||||
|
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input, past_key_values_length)
|
positions = self.embed_positions(input, past_key_values_length)
|
||||||
@@ -1409,6 +1793,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
base_model_prefix = "encoder_decoder"
|
base_model_prefix = "encoder_decoder"
|
||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,9 +21,7 @@ from ..auto.configuration_auto import AutoConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
|
||||||
"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MusicgenMelodyDecoderConfig(PretrainedConfig):
|
class MusicgenMelodyDecoderConfig(PretrainedConfig):
|
||||||
@@ -254,3 +252,20 @@ class MusicgenMelodyConfig(PretrainedConfig):
|
|||||||
# This is a property because you might want to change the codec model on the fly
|
# This is a property because you might want to change the codec model on the fly
|
||||||
def sampling_rate(self):
|
def sampling_rate(self):
|
||||||
return self.audio_encoder.sampling_rate
|
return self.audio_encoder.sampling_rate
|
||||||
|
|
||||||
|
@property
def _attn_implementation(self):
    """Return the attention implementation name stored on this config.

    Falls back to `"eager"` when `_attn_implementation_internal` has never been set or is `None`.
    """
    # This property is made private for now (as it cannot be changed and a
    # PreTrainedModel.use_attn_implementation method needs to be implemented.)
    stored = getattr(self, "_attn_implementation_internal", None)
    # `config.attn_implementation` should never be None, for backward compatibility.
    return "eager" if stored is None else stored


@_attn_implementation.setter
def _attn_implementation(self, value):
    """Store `value` on this config and forward it to the nested decoder config."""
    self._attn_implementation_internal = value
    # Keep the decoder sub-config in sync with the composite config.
    self.decoder._attn_implementation = value
|
||||||
|
|||||||
@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...generation.configuration_utils import GenerationConfig
|
from ...generation.configuration_utils import GenerationConfig
|
||||||
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
||||||
from ...generation.stopping_criteria import StoppingCriteriaList
|
from ...generation.stopping_criteria import StoppingCriteriaList
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -37,6 +38,8 @@ from ...modeling_utils import PreTrainedModel
|
|||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_flash_attn_2_available,
|
||||||
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -45,6 +48,10 @@ from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
|
|||||||
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
|
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ...generation.streamers import BaseStreamer
|
from ...generation.streamers import BaseStreamer
|
||||||
|
|
||||||
@@ -53,10 +60,20 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "MusicgenMelodyConfig"
|
_CONFIG_FOR_DOC = "MusicgenMelodyConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/musicgen-melody"
|
_CHECKPOINT_FOR_DOC = "facebook/musicgen-melody"
|
||||||
|
|
||||||
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||||||
"facebook/musicgen-melody",
|
|
||||||
# See all Musicgen Melody models at https://huggingface.co/models?filter=musicgen_melody
|
|
||||||
]
|
# Adapted from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    """Compute the bookkeeping needed to strip padding before a variable-length attention call.

    Args:
        attention_mask: mask whose last dimension is summed per row and whose non-zero entries mark
            positions to keep.

    Returns:
        A tuple `(indices, cu_seqlens, max_seqlen_in_batch)`:
        - `indices`: flat indices of the non-zero entries of the flattened mask,
        - `cu_seqlens`: int32 cumulative per-row counts, with a leading 0,
        - `max_seqlen_in_batch`: the largest per-row count, as a Python number.
    """
    tokens_per_row = attention_mask.sum(dim=-1, dtype=torch.int32)
    kept_positions = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    longest_row = tokens_per_row.max().item()
    running_total = torch.cumsum(tokens_per_row, dim=0, dtype=torch.int32)
    # Prepend a zero so that consecutive entries delimit each row's span.
    cu_seqlens = F.pad(running_total, (1, 0))
    return kept_positions, cu_seqlens, longest_row
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -324,17 +341,348 @@ class MusicgenMelodyAttention(nn.Module):
|
|||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights_reshaped, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody.
# Deliberately not a verbatim copy: attention dropout is only applied in training mode here.
class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
    """
    MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the
    module stay untouched. The only required change is in the forward pass, which needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces
        # a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View `tensor` as `(bsz, seq_len, num_heads, head_dim)`, the layout passed to the flash attention calls."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self- or cross-attention through the flash attention kernels.

        Args:
            hidden_states: query-side input, unpacked as `(batch, query_length, channels)`.
            key_value_states: when given, the layer acts as cross-attention and keys/values are projected from it.
            past_key_value: cached `(key, value)` pair. It is stored head-first (dim 1 and 2 are swapped relative
                to the layout used inside this method), hence the transposes below.
            attention_mask: padding mask forwarded to `_flash_attention_forward`, or `None`.
            layer_head_mask: accepted for interface compatibility; it is not used by this implementation.
            output_attentions: must be `False`.

        Returns:
            `(attn_output, None, past_key_value)`. Attention weights are never returned.

        Raises:
            ValueError: if `output_attentions` is `True`.
        """
        # MusicgenMelodyFlashAttention2 attention does not support output_attentions
        if output_attentions:
            raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        # get query proj
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # The flash attention kernels apply `dropout_p` regardless of the module mode, so dropout has to be
        # disabled explicitly outside of training (same gating as the SDPA attention class in this file).
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=self.dropout if self.training else 0.0,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        # `output_attentions=True` raised above, so there are never attention weights to return.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    # Same code as transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Length of the query sequence, used to re-pad the output and to decide on causal masking with
                flash_attn<2.1.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Scatter the unpadded result back into a `(batch_size, query_length, ...)` tensor.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Same code as transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Remove padded positions from the query/key/value tensors according to `attention_mask`.

        Returns:
            `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`, i.e. the unpadded tensors together with the
            arguments `flash_attn_varlen_func` and `pad_input` are called with.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Queries and keys cover the same positions: reuse the key-side indices and lengths.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query per batch entry: every query "sequence" has length 1.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody
class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
    """
    MusicgenMelody attention computed with `torch.nn.functional.scaled_dot_product_attention`. Requests that the
    fused kernel cannot serve (`output_attentions=True` or a `layer_head_mask`) are delegated to the parent class.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        needs_manual_attention = output_attentions or layer_head_mask is not None
        if needs_manual_attention:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        bsz, tgt_len, _ = hidden_states.size()

        # Project the queries first; they are split into heads further down.
        query_states = self.q_proj(hidden_states)

        # A non-None `key_value_states` means this layer is used as a cross-attention layer for the decoder.
        if key_value_states is not None:
            # The cache is only reusable when its sequence length matches the provided `key_value_states`
            # (this check is what supports prefix tuning).
            cache_matches = (
                past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
            )
            if cache_matches:
                # cross-attention, cached keys/values
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            else:
                # cross-attention, fresh projection of the encoder states
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        else:
            # self-attention: always project the current hidden states ...
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value is not None:
                # ... and append them to the cached keys/values when a cache is given
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Hand the keys/values back as the cache: cross-attention layers reuse them as-is on later calls,
            # uni-directional self-attention layers concatenate newly projected states onto them.
            # For encoder bi-directional self-attention `past_key_value` is always `None`.
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        dropout_p = self.dropout if self.training else 0.0
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
        use_causal_kernel = self.is_causal and attention_mask is None and tgt_len > 1
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=use_causal_kernel,
        )

        expected_size = (bsz, self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_size:
            raise ValueError(
                f"`attn_output` should be of size {expected_size}, but is"
                f" {attn_output.size()}"
            )

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)

        return self.out_proj(attn_output), None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Lookup table from an attention implementation name to the attention class implementing it.
MUSICGEN_MELODY_ATTENTION_CLASSES = dict(
    eager=MusicgenMelodyAttention,
    sdpa=MusicgenMelodySdpaAttention,
    flash_attention_2=MusicgenMelodyFlashAttention2,
)
|
||||||
|
|
||||||
|
|
||||||
class MusicgenMelodyDecoderLayer(nn.Module):
|
class MusicgenMelodyDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: MusicgenMelodyDecoderConfig):
|
def __init__(self, config: MusicgenMelodyDecoderConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = MusicgenMelodyAttention(
|
self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
is_causal=True,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = ACT2FN[config.activation_function]
|
||||||
@@ -414,6 +762,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_factor
|
std = self.config.initializer_factor
|
||||||
@@ -626,6 +976,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
self.attn_implementation = config._attn_implementation
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
@@ -695,9 +1046,21 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
|
|||||||
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
if self.attn_implementation == "flash_attention_2":
|
||||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
)
|
elif self.attn_implementation == "sdpa" and not output_attentions:
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(inputs_embeds, past_key_values_length)
|
positions = self.embed_positions(inputs_embeds, past_key_values_length)
|
||||||
@@ -1373,6 +1736,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
config_class = MusicgenMelodyConfig
|
config_class = MusicgenMelodyConfig
|
||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user