Add Flash Attention 2 support to Musicgen and Musicgen Melody (#29939)
* add FA2 to o.g Musicgen * make style * add FA2 support to Musicgen Melody * add generation FA2 tests to o.g Musicgen * make style and fix copies * add Musicgen to FA2 docs + deprecate list * add sdpa supports to Musicgen's * make style and fix copies * refactor attention implementation arguments * add Copied from to sdpa tests * add copied form in sdpa tests melody * add copied for FA2 generation tests * add FA2 inference copied from * make style
This commit is contained in:
@@ -55,6 +55,8 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
|
||||
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
|
||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||
@@ -190,6 +192,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
|
||||
|
||||
<Tip>
|
||||
|
||||
|
||||
@@ -1470,6 +1470,12 @@ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
|
||||
|
||||
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
|
||||
|
||||
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
|
||||
{"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json"}
|
||||
)
|
||||
|
||||
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-melody"])
|
||||
|
||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
|
||||
[
|
||||
"RUCAIBox/mvp",
|
||||
|
||||
@@ -239,3 +239,20 @@ class MusicgenConfig(PretrainedConfig):
|
||||
# This is a property because you might want to change the codec model on the fly
|
||||
def sampling_rate(self):
|
||||
return self.audio_encoder.sampling_rate
|
||||
|
||||
@property
|
||||
def _attn_implementation(self):
|
||||
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
|
||||
if hasattr(self, "_attn_implementation_internal"):
|
||||
if self._attn_implementation_internal is None:
|
||||
# `config.attn_implementation` should never be None, for backward compatibility.
|
||||
return "eager"
|
||||
else:
|
||||
return self._attn_implementation_internal
|
||||
else:
|
||||
return "eager"
|
||||
|
||||
@_attn_implementation.setter
|
||||
def _attn_implementation(self, value):
|
||||
self._attn_implementation_internal = value
|
||||
self.decoder._attn_implementation = value
|
||||
|
||||
@@ -22,13 +22,19 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
||||
from ...generation.stopping_criteria import StoppingCriteriaList
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@@ -40,6 +46,8 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -48,6 +56,10 @@ from ..auto.modeling_auto import AutoModel
|
||||
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...generation.streamers import BaseStreamer
|
||||
|
||||
@@ -60,6 +72,19 @@ _CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
|
||||
from ..deprecated._archive_maps import MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MusicgenUnconditionalInput(ModelOutput):
|
||||
"""
|
||||
@@ -302,29 +327,361 @@ class MusicgenAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
|
||||
class MusicgenFlashAttention2(MusicgenAttention):
|
||||
"""
|
||||
Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# MusicgenFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0].transpose(1, 2)
|
||||
value_states = past_key_value[1].transpose(1, 2)
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
||||
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||
def _flash_attention_forward(
|
||||
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`float`):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
"""
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
|
||||
class MusicgenSdpaAttention(MusicgenAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
MUSICGEN_ATTENTION_CLASSES = {
|
||||
"eager": MusicgenAttention,
|
||||
"sdpa": MusicgenSdpaAttention,
|
||||
"flash_attention_2": MusicgenFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class MusicgenDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MusicgenDecoderConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = MusicgenAttention(
|
||||
self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=False,
|
||||
is_causal=True,
|
||||
config=config,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = MusicgenAttention(
|
||||
self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.num_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=False,
|
||||
config=config,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
|
||||
@@ -432,6 +789,8 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_factor
|
||||
@@ -667,6 +1026,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
|
||||
|
||||
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
||||
self.attn_implementation = config._attn_implementation
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@@ -721,16 +1081,40 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
if self.attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
if self.attn_implementation == "flash_attention_2":
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
encoder_attention_mask,
|
||||
inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1],
|
||||
)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input, past_key_values_length)
|
||||
@@ -1409,6 +1793,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
||||
base_model_prefix = "encoder_decoder"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -21,9 +21,7 @@ from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json",
|
||||
}
|
||||
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
|
||||
|
||||
|
||||
class MusicgenMelodyDecoderConfig(PretrainedConfig):
|
||||
@@ -254,3 +252,20 @@ class MusicgenMelodyConfig(PretrainedConfig):
|
||||
# This is a property because you might want to change the codec model on the fly
|
||||
def sampling_rate(self):
|
||||
return self.audio_encoder.sampling_rate
|
||||
|
||||
@property
|
||||
def _attn_implementation(self):
|
||||
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
|
||||
if hasattr(self, "_attn_implementation_internal"):
|
||||
if self._attn_implementation_internal is None:
|
||||
# `config.attn_implementation` should never be None, for backward compatibility.
|
||||
return "eager"
|
||||
else:
|
||||
return self._attn_implementation_internal
|
||||
else:
|
||||
return "eager"
|
||||
|
||||
@_attn_implementation.setter
|
||||
def _attn_implementation(self, value):
|
||||
self._attn_implementation_internal = value
|
||||
self.decoder._attn_implementation = value
|
||||
|
||||
@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
|
||||
from ...generation.stopping_criteria import StoppingCriteriaList
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
ModelOutput,
|
||||
@@ -37,6 +38,8 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -45,6 +48,10 @@ from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
|
||||
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...generation.streamers import BaseStreamer
|
||||
|
||||
@@ -53,10 +60,20 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "MusicgenMelodyConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/musicgen-melody"
|
||||
|
||||
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/musicgen-melody",
|
||||
# See all Musicgen Melody models at https://huggingface.co/models?filter=musicgen_melody
|
||||
]
|
||||
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -324,17 +341,348 @@ class MusicgenMelodyAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody
|
||||
class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
|
||||
"""
|
||||
MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# MusicgenMelodyFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0].transpose(1, 2)
|
||||
value_states = past_key_value[1].transpose(1, 2)
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
||||
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||
def _flash_attention_forward(
|
||||
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`float`):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
"""
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody
|
||||
class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
MUSICGEN_MELODY_ATTENTION_CLASSES = {
|
||||
"eager": MusicgenMelodyAttention,
|
||||
"sdpa": MusicgenMelodySdpaAttention,
|
||||
"flash_attention_2": MusicgenMelodyFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class MusicgenMelodyDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MusicgenMelodyDecoderConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = MusicgenMelodyAttention(
|
||||
self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=False,
|
||||
is_causal=True,
|
||||
config=config,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@@ -414,6 +762,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_factor
|
||||
@@ -626,6 +976,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
|
||||
|
||||
self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
||||
self.attn_implementation = config._attn_implementation
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@@ -695,9 +1046,21 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
|
||||
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
if self.attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self.attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(inputs_embeds, past_key_values_length)
|
||||
@@ -1373,6 +1736,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
||||
config_class = MusicgenMelodyConfig
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user