Add Flash Attention 2 support to Musicgen and Musicgen Melody (#29939)

* add FA2 to o.g Musicgen

* make style

* add FA2 support to Musicgen Melody

* add generation FA2 tests to o.g Musicgen

* make style and fix copies

* add Musicgen to FA2 docs + deprecate list

* add sdpa supports to Musicgen's

* make style and fix copies

* refactor attention implementation arguments

* add Copied from to sdpa tests

* add copied form in sdpa tests melody

* add copied for FA2 generation tests

* add FA2 inference copied from

* make style
This commit is contained in:
Yoach Lacombe
2024-04-02 11:23:49 +01:00
committed by GitHub
parent fed27ffc7e
commit 0d04b1e25a
8 changed files with 3313 additions and 24 deletions

View File

@@ -55,6 +55,8 @@ FlashAttention-2 is currently supported for the following architectures:
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
@@ -190,6 +192,8 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
<Tip>

View File

@@ -1470,6 +1470,12 @@ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
{"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json"}
)
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-melody"])
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
[
"RUCAIBox/mvp",

View File

@@ -239,3 +239,20 @@ class MusicgenConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value

View File

@@ -22,13 +22,19 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -40,6 +46,8 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -48,6 +56,10 @@ from ..auto.modeling_auto import AutoModel
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -60,6 +72,19 @@ _CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
from ..deprecated._archive_maps import MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
class MusicgenUnconditionalInput(ModelOutput):
"""
@@ -302,29 +327,361 @@ class MusicgenAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
class MusicgenFlashAttention2(MusicgenAttention):
"""
Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# MusicgenFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class MusicgenSdpaAttention(MusicgenAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
MUSICGEN_ATTENTION_CLASSES = {
"eager": MusicgenAttention,
"sdpa": MusicgenSdpaAttention,
"flash_attention_2": MusicgenFlashAttention2,
}
class MusicgenDecoderLayer(nn.Module):
def __init__(self, config: MusicgenDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MusicgenAttention(
self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MusicgenAttention(
self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
@@ -432,6 +789,8 @@ class MusicgenPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_factor
@@ -667,6 +1026,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -721,16 +1081,40 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
if self.attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
@@ -1409,6 +1793,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,

View File

@@ -21,9 +21,7 @@ from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json",
}
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
class MusicgenMelodyDecoderConfig(PretrainedConfig):
@@ -254,3 +252,20 @@ class MusicgenMelodyConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value

View File

@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutputWithPast,
ModelOutput,
@@ -37,6 +38,8 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -45,6 +48,10 @@ from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -53,10 +60,20 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MusicgenMelodyConfig"
_CHECKPOINT_FOR_DOC = "facebook/musicgen-melody"
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/musicgen-melody",
# See all Musicgen Melody models at https://huggingface.co/models?filter=musicgen_melody
]
from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
@dataclass
@@ -324,17 +341,348 @@ class MusicgenMelodyAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody
class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
"""
MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# MusicgenMelodyFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len, _ = hidden_states.size()
# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody
class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
MUSICGEN_MELODY_ATTENTION_CLASSES = {
"eager": MusicgenMelodyAttention,
"sdpa": MusicgenMelodySdpaAttention,
"flash_attention_2": MusicgenMelodyFlashAttention2,
}
class MusicgenMelodyDecoderLayer(nn.Module):
def __init__(self, config: MusicgenMelodyDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MusicgenMelodyAttention(
self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@@ -414,6 +762,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_factor
@@ -626,6 +976,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -695,9 +1046,21 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
input_shape = inputs_embeds.size()[:-1]
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# embed positions
positions = self.embed_positions(inputs_embeds, past_key_values_length)
@@ -1373,6 +1736,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
config_class = MusicgenMelodyConfig
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff