add sdpa mbart (#32033)
* add sdpa mbart useful for donut * update sdpa docs * formatting * add self._use_sdpa in mbartencoder * use self.config to check attn * retrigger checks * [run-slow] mbart
This commit is contained in:
@@ -239,6 +239,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
||||||
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
||||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||||
|
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||||
|
|||||||
@@ -24,7 +24,12 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_attention_mask,
|
||||||
|
_prepare_4d_attention_mask_for_sdpa,
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -405,8 +410,116 @@ class MBartFlashAttention2(MBartAttention):
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart
|
||||||
|
class MBartSdpaAttention(MBartAttention):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
if output_attentions or layer_head_mask is not None:
|
||||||
|
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||||
|
logger.warning_once(
|
||||||
|
"MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||||
|
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states,
|
||||||
|
key_value_states=key_value_states,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
|
# for the decoder
|
||||||
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
# get key, value proj
|
||||||
|
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||||
|
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||||
|
# the provided `key_value_states` to support prefix tuning
|
||||||
|
if (
|
||||||
|
is_cross_attention
|
||||||
|
and past_key_value is not None
|
||||||
|
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||||
|
):
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value[0]
|
||||||
|
value_states = past_key_value[1]
|
||||||
|
elif is_cross_attention:
|
||||||
|
# cross_attentions
|
||||||
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||||
|
elif past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
else:
|
||||||
|
# self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz)
|
||||||
|
|
||||||
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
|
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||||
|
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||||
|
|
||||||
|
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||||
|
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=self.dropout if self.training else 0.0,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||||
|
# partitioned across GPUs when using tensor-parallelism.
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
MBART_ATTENTION_CLASSES = {
|
MBART_ATTENTION_CLASSES = {
|
||||||
"eager": MBartAttention,
|
"eager": MBartAttention,
|
||||||
|
"sdpa": MBartSdpaAttention,
|
||||||
"flash_attention_2": MBartFlashAttention2,
|
"flash_attention_2": MBartFlashAttention2,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -632,6 +745,7 @@ class MBartPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
|
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
@@ -841,7 +955,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
embed_dim,
|
embed_dim,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
self.config = config
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||||
|
|
||||||
@@ -929,9 +1043,13 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
# expand attention_mask
|
# expand attention_mask
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if self._use_flash_attention_2:
|
|
||||||
attention_mask = attention_mask if 0 in attention_mask else None
|
attention_mask = attention_mask if 0 in attention_mask else None
|
||||||
|
elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
|
||||||
|
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||||
else:
|
else:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||||
@@ -1021,7 +1139,8 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
config.d_model,
|
config.d_model,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
self.config = config
|
||||||
|
|
||||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||||
|
|
||||||
@@ -1141,9 +1260,18 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if self._use_flash_attention_2:
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
# 2d mask is passed through the layers
|
# 2d mask is passed through the layers
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
|
elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
|
||||||
|
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 4d mask is passed through the layers
|
# 4d mask is passed through the layers
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
@@ -1152,8 +1280,17 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
# expand encoder attention mask
|
# expand encoder attention mask
|
||||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||||
if self._use_flash_attention_2:
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||||
|
elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
|
||||||
|
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||||
|
encoder_attention_mask,
|
||||||
|
inputs_embeds.dtype,
|
||||||
|
tgt_len=input_shape[-1],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||||
|
|||||||
Reference in New Issue
Block a user