Kernels flash attn (#39474)

* use partial to wrap around `transformers` utils!

* try to refactor?

* revert one wrong change

* just a nit

* push

* reverter watever was wrong!

* some nits

* fixes when there is no attention mask

* bring the licence back

* some fixes

* nit

* style

* remove prints

* correct dtype

* fa flags for testing

* update

* use paged attention if requested!

* updates

* a clone was needed, not sure why

* automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute

* simplify and improve?

* flash attention is kinda broken on recent cuda version so allow the opportunity to use something else

* fix!

* protect kernels import

* update

* properly parse generation config being passed

* revert and update

* add two tests

* some fixes

* fix test FA2

* takes comment into account

* fixup

* revert changes

* revert the clone, it is only needed because the metal kernel is not doing it?

* [docs] update attention implementation and cache docs (#39547)

* update docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* applu suggestions

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix mps on our side for now

* Update src/transformers/integrations/flash_paged.py

* no qa

---------

Co-authored-by: Vasqu <antonprogamer@gmail.com>
Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-22 15:41:06 +02:00
committed by GitHub
parent b62557e712
commit efceeaf267
9 changed files with 336 additions and 421 deletions

View File

@@ -1119,7 +1119,8 @@ class ContinuousBatchingManager:
self._request_lock = threading.Lock()
self.model.generation_config.top_p = None
self.do_sample = getattr(generation_config, "do_sample", True)
self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
generation_config = model.generation_config if generation_config is None else generation_config
self.logit_processor = self.model._get_logits_processor(generation_config)
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction

View File

@@ -677,6 +677,24 @@ class GenerationMixin(ContinuousMixin):
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
tensor_kws = {"dtype": torch.int32, "device": self.device}
pos = model_inputs["position_ids"][:, -1]
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
max_length_k = int(pos.max()) + 1
bs, seq_len = input_ids.size()
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
max_length_q = int(q_len.max())
model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
max_length_q=max_length_q,
max_length_k=max_length_k,
)
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:

View File

@@ -38,7 +38,6 @@ def flash_attention_forward(
"FlashAttention does not support inputs with dim=0.\n"
"Please check your input shapes or use SDPA instead."
)
# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -76,6 +75,7 @@ def flash_attention_forward(
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
**kwargs,
)

View File

@@ -5,7 +5,7 @@ from ..utils import is_flash_attn_2_available
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
def paged_attention_forward(
@@ -20,6 +20,7 @@ def paged_attention_forward(
max_seqlen_q=None,
max_seqlen_k=None,
block_tables=None,
implementation=None,
**kwargs,
) -> torch.Tensor:
r"""Perform the forward pass of attention with paged key-value cache.
@@ -46,12 +47,14 @@ def paged_attention_forward(
"""
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
if implementation is not None:
flash_attn_varlen_func = implementation.flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
q.transpose(1, 2).squeeze(0),
k.transpose(1, 2).squeeze(0),
v.transpose(1, 2).squeeze(0),
q.transpose(1, 2).squeeze(0).contiguous(),
k.transpose(1, 2).squeeze(0).contiguous(),
v.transpose(1, 2).squeeze(0).contiguous(),
cumulative_seqlens_q.to(torch.int32),
cumulative_seqlens_k.to(torch.int32),
cumulative_seqlens_k.to(torch.int32).clone(),
max_seqlen_q,
max_seqlen_k,
softmax_scale=module.scaling,

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
@@ -20,6 +19,8 @@ from typing import Optional, TypedDict
import torch
import torch.nn.functional as F
from transformers.utils.import_utils import is_kernels_available
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
@@ -31,25 +32,16 @@ from .utils import (
logger = logging.get_logger(__name__)
flash_attn_func = None
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
return reshaped_tensor[indices]
def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
return reshaped[indices]
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
FA3-compatible unpad_input function.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
@@ -80,7 +72,6 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
"""
FA3-compatible pad_input function.
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
@@ -95,109 +86,12 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
return output.view(batch, seqlen, *dim)
FA_VERSION = None
if is_flash_attn_2_available():
from flash_attn import flash_attn_func as flash_attn_2_func
from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
from flash_attn.bert_padding import pad_input as pad_input_fa2
from flash_attn.bert_padding import unpad_input as unpad_input_fa2
from flash_attn.layers.rotary import apply_rotary_emb
HAS_FA2 = True
FA_VERSION = 2
elif is_torch_npu_available():
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
from .integrations.npu_flash_attention import pad_input as pad_input_fa2
from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
HAS_FA2 = True
FA_VERSION = 2
else:
flash_attn_2_func = None
flash_attn_2_varlen_func = None
pad_input_fa2 = None
unpad_input_fa2 = None
apply_rotary_emb = None
HAS_FA2 = False
if is_flash_attn_3_available():
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
pad_input_fa3 = _fa3_pad_input
unpad_input_fa3 = _fa3_unpad_input
HAS_FA3 = True
FA_VERSION = 3
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
pad_input_fa3 = None
unpad_input_fa3 = None
HAS_FA3 = False
# Current Flash Attention implementations
if FA_VERSION:
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
_flash_supports_window_size = False
if flash_attn_func:
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
def is_flash_attn_available():
"""Determine whether flash-attention can be used or not."""
if is_flash_attn_3_available():
return True
# if package `flash-attn` is available, flash-attention can be used natively.
if is_flash_attn_2_available():
return True
# flash-attention can be used on Ascend NPU without package `flash-attn`
if is_torch_npu_available():
return True
return False
def flash_attn_supports_top_left_mask():
"""Determine whether flash-attention uses top-left or down-right mask"""
if is_flash_attn_3_available():
return False
if is_flash_attn_2_available():
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
return not is_flash_attn_greater_or_equal_2_10()
if is_torch_npu_available():
# down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask.
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
return is_npu_fa2_top_left_aligned_causal_mask()
return False
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
@@ -229,10 +123,8 @@ def _upad_input(
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
tensors for query, key, value tensors.
Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@@ -246,7 +138,6 @@ def _upad_input(
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.
Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@@ -299,14 +190,12 @@ def _upad_input(
)
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@@ -316,7 +205,6 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@@ -331,19 +219,22 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.view(-1, query.size(-2), query.size(-1))
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
cu_seqlens_k = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
)
max_k = torch.max(position_ids, dim=1).values.max().item() + 1
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor([0], device=position_ids.device, dtype=torch.int32),
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
@@ -353,61 +244,101 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
def prepare_fa2_from_position_ids(*args, **kwargs):
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
warnings.warn(
"The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
"prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
FutureWarning,
)
return _prepare_flash_attention_from_position_ids(*args, **kwargs)
return _prepare_from_posids(query, key, value, position_ids)
def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
if target_dtype and q.dtype == torch.float32:
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
return q, k, v
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
"""
if target_dtype is None:
return query, key, value
input_dtype = query.dtype
if input_dtype == torch.float32:
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
def _lazy_imports(impl: Optional[str]):
# returns funcs and pad/unpad based on impl
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
is_fa3 = is_flash_attn_3_available()
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
except ImportError as e:
if not globals().get("use_remote_fa2", None):
use_remote_fa2 = (
input(
"Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
)
.strip()
.lower()
)
globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
if globals()["use_remote_fa2"]:
if not is_kernels_available():
raise ImportError("You need to install kernels: `pip install kernels`")
from kernels import get_kernel
impl = get_kernel("kernels-community/flash-attn")
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return (
getattr(impl, "flash_attn_func", None),
getattr(impl, "flash_attn_varlen_func"),
pad_input,
unpad_input,
True,
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
else:
raise ImportError(
"Failed to import flash attention 2, please install it or use another implementation."
) from e
if impl == "flash_attention_3" or (impl is None and is_fa3):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
return query, key, value
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
else:
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return (
getattr(impl, "flash_attn_func", None),
getattr(impl, "flash_attn_varlen_func"),
pad_input,
unpad_input,
True,
)
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = None
_flash_supports_window = None
def is_flash_attn_available():
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
def flash_attn_supports_top_left_mask():
if is_flash_attn_3_available():
return False
if is_flash_attn_2_available():
return not is_flash_attn_greater_or_equal_2_10()
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
return is_npu_fa2_top_left_aligned_causal_mask()
class FlashAttentionKwargs(TypedDict, total=False):
cumulative_seqlens_q: Optional[torch.LongTensor]
cumulative_seqlens_k: Optional[torch.LongTensor]
def _flash_attention_forward(
@@ -429,185 +360,100 @@ def _flash_attention_forward(
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
attn_implementation: Optional[str] = None,
implementation: Optional[str] = None,
**kwargs,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`, *optional*):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
use_top_left_mask (`bool`, defaults to `False`):
flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
softcap (`float`, *optional*):
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
attn_implementation (`str`, *optional*):
The attention implementation to use. If None, will default to the one based on the environment.
"""
if attn_implementation is None:
_flash_attn_varlen_func = flash_attn_varlen_func
_flash_attn_func = flash_attn_func
_pad_input = pad_input
_unpad_input = unpad_input
_is_fa3 = HAS_FA3
elif attn_implementation == "flash_attention_3":
_flash_attn_varlen_func = flash_attn_3_varlen_func
_flash_attn_func = flash_attn_3_func
_pad_input = pad_input_fa3
_unpad_input = unpad_input_fa3
_is_fa3 = True
elif attn_implementation == "flash_attention_2":
_flash_attn_varlen_func = flash_attn_2_varlen_func
_flash_attn_func = flash_attn_2_func
_pad_input = pad_input_fa2
_unpad_input = unpad_input_fa2
_is_fa3 = False
if not use_top_left_mask:
causal = is_causal
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
globals()["_flash_fn"] = flash_fn
globals()["_flash_varlen_fn"] = flash_varlen_fn
globals()["_pad_fn"] = pad_fn
globals()["_unpad_fn"] = unpad_fn
globals()["_is_fa3"] = is_fa3
flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
globals()["_flash_supports_window"] = flash_supports_window
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
causal = is_causal and query_length != 1
flash_fn = globals()["_flash_fn"]
flash_varlen_fn = globals()["_flash_varlen_fn"]
pad_fn = globals()["_pad_fn"]
unpad_fn = globals()["_unpad_fn"]
is_fa3 = globals()["_is_fa3"]
flash_supports_window = globals()["_flash_supports_window"]
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
causal = is_causal and not (use_top_left_mask and query_length == 1)
use_sw = (
(_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if _is_fa3:
if dropout > 0.0:
logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
else:
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
if not is_fa3:
flash_kwargs["dropout_p"] = dropout
if flash_241:
if deterministic is None:
global deterministic_g
if deterministic_g is None:
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic
if is_flash_attn_greater_or_equal("2.4.1"):
det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = det
if softcap is not None:
flash_kwargs["softcap"] = softcap
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)
# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
# under two cases:
# Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
# use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility
# to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information
is_fa2_with_position_ids = (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
is_fa2_with_varlen_kwargs = all(
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
# Contains at least one padding token in the sequence
use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query_states, key_states, value_states, attention_mask, query_length, _unpad_input
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
# TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
if "mps" in str(q.device):
cu_k = cu_k.clone()
out_unpad = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_q.to(torch.int32),
cu_seqlens_k=cu_k.to(torch.int32),
max_seqlen_q=mq,
max_seqlen_k=mk,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
batch_size = query_states.size(0)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
elif use_mask:
if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
_prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
if position_ids is None:
raise ValueError(
"Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
)
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
query_states, key_states, value_states, position_ids
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
attn_output = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
mq, mk = max_length_q, max_length_k
cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
if "mps" in str(q.device):
cu_k = cu_k.clone()
out = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_q.to(torch.int32),
cu_seqlens_k=cu_k.to(torch.int32),
max_seqlen_q=mq,
max_seqlen_k=mk,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
if isinstance(out, tuple):
out = out[0]
out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
else:
attn_output = _flash_attn_func(
out = flash_fn(
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
if isinstance(attn_output, tuple):
return attn_output[0]
return attn_output
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
cumulative_seqlens_q (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for query state.
cumulative_seqlens_k (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cumulative_seqlens_q: Optional[torch.LongTensor]
cumulative_seqlens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
return out[0] if isinstance(out, tuple) else out

View File

@@ -72,6 +72,7 @@ from .integrations.tensor_parallel import (
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
None to sdpa (to potentially eager).
"""
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = applicable_attn_implementation.split(":")
if ":" in applicable_attn_implementation:
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
else:
repo_id = attn_implementation
kernel_name = None
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
if hasattr(kernel, "flash_attn_varlen_func"):
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
flash_attention_forward, implementation=kernel
)
elif kernel_name is not None:
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
"flash_attention_2"
]
applicable_attn_implementation = repo_id
except FileNotFoundError as e:
logger.warning_once(
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
"default attention implementation instead (sdpa if available, eager otherwise)."
)
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
)
finally:
return applicable_attn_implementation
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
message = (
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '

View File

@@ -104,6 +104,7 @@ from .utils import (
is_jinja_available,
is_jumanpp_available,
is_keras_nlp_available,
is_kernels_available,
is_levenshtein_available,
is_librosa_available,
is_liger_kernel_available,
@@ -586,6 +587,16 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
def require_kernels(test_case):
"""
Decorator marking a test that requires Flash Attention.
These tests are skipped when Flash Attention isn't installed.
"""
return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case)
def require_flash_attn_3(test_case):
"""
Decorator marking a test that requires Flash Attention 3.
@@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_mps(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
def require_large_cpu_ram(test_case, memory: float = 80):
"""Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
if not is_psutil_available():

View File

@@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict:
for placeholder in placeholders:
# Infer placeholders from the model name and the auto modules
if placeholder in PLACEHOLDER_TO_AUTO_MODULE:
try:
place_holder_value = getattr(
getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
).get(model_name, None)
except ImportError:
# In case a library is not installed, we don't want to fail the docstring generation
place_holder_value = None
if place_holder_value is not None:
if isinstance(place_holder_value, (list, tuple)):
place_holder_value = place_holder_value[0]
@@ -1170,8 +1174,11 @@ def format_args_docstring(docstring, model_name):
placeholders_dict = get_placeholders_dict(placeholders, model_name)
# replace the placeholders in the docstring with the values from the placeholders_dict
for placeholder, value in placeholders_dict.items():
if placeholder is not None:
try:
docstring = docstring.replace(f"{{{placeholder}}}", value)
except Exception:
pass
return docstring

View File

@@ -86,12 +86,14 @@ from transformers.testing_utils import (
require_deepspeed,
require_flash_attn,
require_flash_attn_3,
require_kernels,
require_non_hpu,
require_safetensors,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_greater_or_equal,
require_torch_mps,
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
@@ -3474,18 +3476,10 @@ class ModelTesterMixin:
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.head_dim = 64 # fa2 does not always support arbitrary headim
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
@@ -3504,15 +3498,16 @@ class ModelTesterMixin:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, output_hidden_states=True)
model.set_attn_implementation("sdpa")
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
@@ -3532,7 +3527,8 @@ class ModelTesterMixin:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
@@ -3541,12 +3537,12 @@ class ModelTesterMixin:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
model.set_attn_implementation(attn_implementation)
outputs_fa = model(dummy_input, **other_inputs)
model.set_attn_implementation("sdpa")
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
@@ -3559,10 +3555,29 @@ class ModelTesterMixin:
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
model.set_attn_implementation(attn_implementation)
_ = model(dummy_input, **other_inputs)
else:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_kernels
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_kernels_inference_equivalence(self):
self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left")
@require_torch_mps
@require_kernels
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_kernels_mps_inference_equivalence(self):
self.flash_attn_inference_equivalence(
attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left"
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test