From efceeaf2678678553e94dce78859f87776e633a7 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 22 Jul 2025 15:41:06 +0200 Subject: [PATCH] Kernels flash attn (#39474) * use partial to wrap around `transformers` utils! * try to refactor? * revert one wrong change * just a nit * push * revert whatever was wrong! * some nits * fixes when there is no attention mask * bring the license back * some fixes * nit * style * remove prints * correct dtype * fa flags for testing * update * use paged attention if requested! * updates * a clone was needed, not sure why * automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute * simplify and improve? * flash attention is kinda broken on recent CUDA versions so allow the opportunity to use something else * fix! * protect kernels import * update * properly parse generation config being passed * revert and update * add two tests * some fixes * fix test FA2 * takes comment into account * fixup * revert changes * revert the clone, it is only needed because the metal kernel is not doing it? 
* [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix mps on our side for now * Update src/transformers/integrations/flash_paged.py * no qa --------- Co-authored-by: Vasqu Co-authored-by: Raushan Turganbay Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../generation/continuous_batching.py | 3 +- src/transformers/generation/utils.py | 18 + .../integrations/flash_attention.py | 2 +- src/transformers/integrations/flash_paged.py | 13 +- .../modeling_flash_attention_utils.py | 476 ++++++------------ src/transformers/modeling_utils.py | 31 +- src/transformers/testing_utils.py | 16 + src/transformers/utils/auto_docstring.py | 19 +- tests/test_modeling_common.py | 179 ++++--- 9 files changed, 336 insertions(+), 421 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 09ee1fe8ce..e462e483c2 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -1119,7 +1119,8 @@ class ContinuousBatchingManager: self._request_lock = threading.Lock() self.model.generation_config.top_p = None self.do_sample = getattr(generation_config, "do_sample", True) - self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + generation_config = model.generation_config if generation_config is None else generation_config + self.logit_processor = self.model._get_logits_processor(generation_config) self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) self.profile = getattr(generation_config, "profile", False) self.manual_eviction = manual_eviction diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 
e360acdac3..3bffb5fdda 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -677,6 +677,24 @@ class GenerationMixin(ContinuousMixin): if encoder_attention_mask is not None: model_inputs["attention_mask"] = encoder_attention_mask + if "flash" in self.config._attn_implementation and self._supports_attention_backend: + tensor_kws = {"dtype": torch.int32, "device": self.device} + pos = model_inputs["position_ids"][:, -1] + + cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0) + max_length_k = int(pos.max()) + 1 + + bs, seq_len = input_ids.size() + q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1) + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0) + max_length_q = int(q_len.max()) + + model_inputs.update( + cu_seq_lens_q=cu_seq_lens_q.to(self.device), + cu_seq_lens_k=cu_seq_lens_k.to(self.device), + max_length_q=max_length_q, + max_length_k=max_length_k, + ) # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): if key not in model_inputs: diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 00df0ef0fd..43c65b46c8 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -38,7 +38,6 @@ def flash_attention_forward( "FlashAttention does not support inputs with dim=0.\n" "Please check your input shapes or use SDPA instead." 
) - # FA2 uses non-transposed inputs query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -76,6 +75,7 @@ def flash_attention_forward( use_top_left_mask=_use_top_left_mask, target_dtype=target_dtype, attn_implementation=module.config._attn_implementation, + layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, **kwargs, ) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index b0463d9524..236e216b3f 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -5,7 +5,7 @@ from ..utils import is_flash_attn_2_available if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_varlen_func # noqa: F401 def paged_attention_forward( @@ -20,6 +20,7 @@ def paged_attention_forward( max_seqlen_q=None, max_seqlen_k=None, block_tables=None, + implementation=None, **kwargs, ) -> torch.Tensor: r"""Perform the forward pass of attention with paged key-value cache. 
@@ -46,12 +47,14 @@ def paged_attention_forward( """ k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + if implementation is not None: + flash_attn_varlen_func = implementation.flash_attn_varlen_func attn_output = flash_attn_varlen_func( - q.transpose(1, 2).squeeze(0), - k.transpose(1, 2).squeeze(0), - v.transpose(1, 2).squeeze(0), + q.transpose(1, 2).squeeze(0).contiguous(), + k.transpose(1, 2).squeeze(0).contiguous(), + v.transpose(1, 2).squeeze(0).contiguous(), cumulative_seqlens_q.to(torch.int32), - cumulative_seqlens_k.to(torch.int32), + cumulative_seqlens_k.to(torch.int32).clone(), max_seqlen_q, max_seqlen_k, softmax_scale=module.scaling, diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 1b5476b0ec..848c2a2141 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect import os import warnings @@ -20,6 +19,8 @@ from typing import Optional, TypedDict import torch import torch.nn.functional as F +from transformers.utils.import_utils import is_kernels_available + from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, @@ -31,25 +32,16 @@ from .utils import ( logger = logging.get_logger(__name__) -flash_attn_func = None -def _index_first_axis(tensor, indices): - """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. This is functionally equivalent to - FA2's `index_first_axis` and replaces the need to import it. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) 
before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] +def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) + return reshaped[indices] def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): """ FA3-compatible unpad_input function. - Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. @@ -80,7 +72,6 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): def _fa3_pad_input(hidden_states, indices, batch, seqlen): """ FA3-compatible pad_input function. - Arguments: hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. @@ -95,109 +86,12 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen): return output.view(batch, seqlen, *dim) -FA_VERSION = None -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func as flash_attn_2_func - from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func - from flash_attn.bert_padding import pad_input as pad_input_fa2 - from flash_attn.bert_padding import unpad_input as unpad_input_fa2 - from flash_attn.layers.rotary import apply_rotary_emb - - HAS_FA2 = True - FA_VERSION = 2 -elif is_torch_npu_available(): - # patch functions in package `flash-attn` when using flash-attention on Ascend NPU. 
- from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401 - from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func - from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func - from .integrations.npu_flash_attention import pad_input as pad_input_fa2 - from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2 - - HAS_FA2 = True - FA_VERSION = 2 -else: - flash_attn_2_func = None - flash_attn_2_varlen_func = None - pad_input_fa2 = None - unpad_input_fa2 = None - apply_rotary_emb = None - HAS_FA2 = False - -if is_flash_attn_3_available(): - from flash_attn_interface import flash_attn_func as flash_attn_3_func - from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func - - pad_input_fa3 = _fa3_pad_input - unpad_input_fa3 = _fa3_unpad_input - HAS_FA3 = True - FA_VERSION = 3 -else: - flash_attn_3_func = None - flash_attn_3_varlen_func = None - pad_input_fa3 = None - unpad_input_fa3 = None - HAS_FA3 = False - - -# Current Flash Attention implementations -if FA_VERSION: - flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"] - flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"] - unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"] - pad_input = globals()[f"pad_input_fa{FA_VERSION}"] - - -_flash_supports_window_size = False - - -if flash_attn_func: - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -def is_flash_attn_available(): - """Determine whether flash-attention can be used or not.""" - - if is_flash_attn_3_available(): - return True - - # if package `flash-attn` is available, flash-attention can be used natively. 
- if is_flash_attn_2_available(): - return True - - # flash-attention can be used on Ascend NPU without package `flash-attn` - if is_torch_npu_available(): - return True - - return False - - -def flash_attn_supports_top_left_mask(): - """Determine whether flash-attention uses top-left or down-right mask""" - - if is_flash_attn_3_available(): - return False - - if is_flash_attn_2_available(): - # top-left mask is used in package `flash-attn` with version lower than 2.1.0 - return not is_flash_attn_greater_or_equal_2_10() - - if is_torch_npu_available(): - # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask - - return is_npu_fa2_top_left_aligned_causal_mask() - - return False - - def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. - Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. @@ -229,10 +123,8 @@ def _upad_input( ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. - Arguments: query_layer (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -246,7 +138,6 @@ def _upad_input( Target length. unpad_input_func: The function to use for unpadding the input tensors. - Return: query_layer (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). 
@@ -299,14 +190,12 @@ def _upad_input( ) -def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): +def _prepare_from_posids(query, key, value, position_ids): """ This function returns necessary arguments to call `flash_attn_varlen_func`. All three query, key, value states will be flattened. Cumulative lengths of each examples in the batch will be extracted from position_ids. - NOTE: ideally cumulative lengths should be prepared at the data collator stage - Arguments: query (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). @@ -316,7 +205,6 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). position_ids (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - Return: query (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). @@ -331,19 +219,22 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
""" - query = query.view(-1, query.size(-2), query.size(-1)) + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + cu_seqlens_k = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0 + ) + max_k = torch.max(position_ids, dim=1).values.max().item() + 1 position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) cu_seq_lens = torch.cat( ( - indices_q[position_ids == 0], + torch.tensor([0], device=position_ids.device, dtype=torch.int32), torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), ) ) - # NOTE: With torch compile, this will cause a graph break if you don't set # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. @@ -353,61 +244,101 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing # for some models (e.g. qwen2-vl). max_length = cu_seq_lens.diff().max().item() - - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k)) -def prepare_fa2_from_position_ids(*args, **kwargs): +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): warnings.warn( - "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. 
Please use `_prepare_flash_attention_from_position_ids` instead.", + "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids", FutureWarning, ) - return _prepare_flash_attention_from_position_ids(*args, **kwargs) + return _prepare_from_posids(query, key, value, position_ids) -def fa_peft_integration_check( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - target_dtype: Optional[torch.dtype] = None, -): - """ - PEFT usually casts the layer norms in float32 for training stability reasons - therefore the input hidden states gets silently casted in float32. Hence, we need - cast them back in float16 / bfloat16 just to be sure everything works as expected. - This might slowdown training & inference so it is recommended to not cast the LayerNorms! +def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): + if target_dtype and q.dtype == torch.float32: + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + return q, k, v - Args: - query (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value (`torch.Tensor`): - Input value states to be passed to Flash Attention API - target_dtype (`torch.dtype`, *optional*): - The dtype to convert the attention tensors to. Conversion can be ignored by - not providing the target dtype. - """ - if target_dtype is None: - return query, key, value - input_dtype = query.dtype - if input_dtype == torch.float32: - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
+def _lazy_imports(impl: Optional[str]): + # returns funcs and pad/unpad based on impl + is_fa2 = is_flash_attn_2_available() or is_torch_npu_available() + is_fa3 = is_flash_attn_3_available() + if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False + + except ImportError as e: + if not globals().get("use_remote_fa2", None): + use_remote_fa2 = ( + input( + "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " + ) + .strip() + .lower() + ) + globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} + if globals()["use_remote_fa2"]: + if not is_kernels_available(): + raise ImportError("You need to install kernels: `pip install kernels`") + from kernels import get_kernel + + impl = get_kernel("kernels-community/flash-attn") + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, + ) + + else: + raise ImportError( + "Failed to import flash attention 2, please install it or use another implementation." 
+ ) from e + if impl == "flash_attention_3" or (impl is None and is_fa3): + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True + else: + pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input + return ( + getattr(impl, "flash_attn_func", None), + getattr(impl, "flash_attn_varlen_func"), + pad_input, + unpad_input, + True, ) - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - return query, key, value +_flash_supports_window = None -flash_241 = is_flash_attn_greater_or_equal("2.4.1") -deterministic_g = None +def is_flash_attn_available(): + return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() + + +def flash_attn_supports_top_left_mask(): + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): + return not is_flash_attn_greater_or_equal_2_10() + + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + return is_npu_fa2_top_left_aligned_causal_mask() + + +class FlashAttentionKwargs(TypedDict, total=False): + cumulative_seqlens_q: Optional[torch.LongTensor] + cumulative_seqlens_k: Optional[torch.LongTensor] def _flash_attention_forward( @@ -429,185 +360,100 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, - attn_implementation: Optional[str] = None, + implementation: Optional[str] = None, **kwargs, ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`, *optional*): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_top_left_mask (`bool`, defaults to `False`): - flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - attn_implementation (`str`, *optional*): - The attention implementation to use. If None, will default to the one based on the environment. 
- """ - if attn_implementation is None: - _flash_attn_varlen_func = flash_attn_varlen_func - _flash_attn_func = flash_attn_func - _pad_input = pad_input - _unpad_input = unpad_input - _is_fa3 = HAS_FA3 - elif attn_implementation == "flash_attention_3": - _flash_attn_varlen_func = flash_attn_3_varlen_func - _flash_attn_func = flash_attn_3_func - _pad_input = pad_input_fa3 - _unpad_input = unpad_input_fa3 - _is_fa3 = True - elif attn_implementation == "flash_attention_2": - _flash_attn_varlen_func = flash_attn_2_varlen_func - _flash_attn_func = flash_attn_2_func - _pad_input = pad_input_fa2 - _unpad_input = unpad_input_fa2 - _is_fa3 = False - - if not use_top_left_mask: - causal = is_causal + if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): + flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) + globals()["_flash_fn"] = flash_fn + globals()["_flash_varlen_fn"] = flash_varlen_fn + globals()["_pad_fn"] = pad_fn + globals()["_unpad_fn"] = unpad_fn + globals()["_is_fa3"] = is_fa3 + flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters + globals()["_flash_supports_window"] = flash_supports_window else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. - causal = is_causal and query_length != 1 + flash_fn = globals()["_flash_fn"] + flash_varlen_fn = globals()["_flash_varlen_fn"] + pad_fn = globals()["_pad_fn"] + unpad_fn = globals()["_unpad_fn"] + is_fa3 = globals()["_is_fa3"] + flash_supports_window = globals()["_flash_supports_window"] - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
- use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + causal = is_causal and not (use_top_left_mask and query_length == 1) + use_sw = ( + (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if _is_fa3: - if dropout > 0.0: - logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") - else: + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} + if not is_fa3: flash_kwargs["dropout_p"] = dropout - - if flash_241: - if deterministic is None: - global deterministic_g - if deterministic_g is None: - deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - deterministic = deterministic_g - flash_kwargs["deterministic"] = deterministic - + if is_flash_attn_greater_or_equal("2.4.1"): + det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = det if softcap is not None: flash_kwargs["softcap"] = softcap - # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) - - # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - # under two cases: - # Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. 
It is safe to - # use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility - # to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information - is_fa2_with_position_ids = ( - position_ids is not None - and query_states.shape[0] == 1 - and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) - ) - is_fa2_with_varlen_kwargs = all( - kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) - ) - - # Contains at least one padding token in the sequence + use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]) if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length, _unpad_input + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, + # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p + if "mps" in str(q.device): + cu_k = cu_k.clone() + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) - attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) - - elif is_fa2_with_varlen_kwargs 
or is_fa2_with_position_ids: - batch_size = query_states.size(0) - + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) + elif use_mask: if cu_seq_lens_q is None or cu_seq_lens_k is None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( - _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids) + if position_ids is None: + raise ValueError( + "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." + ) + q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( + query_states, key_states, value_states, position_ids ) - - cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens - max_length_q, max_length_k = max_seq_lens - else: - query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - - attn_output = _flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + mq, mk = max_length_q, max_length_k + cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k + if "mps" in str(q.device): + cu_k = cu_k.clone() + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_q.to(torch.int32), + cu_seqlens_k=cu_k.to(torch.int32), + max_seqlen_q=mq, + max_seqlen_k=mk, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) - - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - + if isinstance(out, tuple): + out = out[0] 
+ out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) else: - attn_output = _flash_attn_func( + out = flash_fn( query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) - if isinstance(attn_output, tuple): - return attn_output[0] - return attn_output - - -class FlashAttentionKwargs(TypedDict, total=False): - """ - Keyword arguments for Flash Attention with Compile. - - Attributes: - cumulative_seqlens_q (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for query state. - cumulative_seqlens_k (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. - """ - - cumulative_seqlens_q: Optional[torch.LongTensor] - cumulative_seqlens_k: Optional[torch.LongTensor] - max_length_q: Optional[int] - max_length_k: Optional[int] + return out[0] if isinstance(out, tuple) else out diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 56e4145250..f4fd894b32 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -72,6 +72,7 @@ from .integrations.tensor_parallel import ( verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING +from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH None to sdpa (to potentially eager). """ applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation - if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation): + if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation): if not is_kernels_available(): raise ValueError("kernels is not installed. 
Please install it with `pip install kernels`.") # Extract repo_id and kernel_name from the string - repo_id, kernel_name = applicable_attn_implementation.split(":") - kernel_name = kernel_name.strip() + if ":" in applicable_attn_implementation: + repo_id, kernel_name = attn_implementation.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = attn_implementation + kernel_name = None repo_id = repo_id.strip() - try: kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) - applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + if hasattr(kernel, "flash_attn_varlen_func"): + ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial( + flash_attention_forward, implementation=kernel + ) + elif kernel_name is not None: + ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) + ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ + "flash_attention_2" + ] + applicable_attn_implementation = repo_id except FileNotFoundError as e: logger.warning_once( f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " "default attention implementation instead (sdpa if available, eager otherwise)." ) applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case - except AttributeError: - raise ValueError( - "the kernel function name or class specified in the attn_implementation argument is not valid. Please check " - "the documentation for the correct format, and check that the kernel exports the class and the function correctly." - ) + finally: + return applicable_attn_implementation if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): message = ( f'Specified `attn_implementation="{attn_implementation}"` is not supported. 
The only possible arguments are ' diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1df380b6fd..0e117d71f7 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -104,6 +104,7 @@ from .utils import ( is_jinja_available, is_jumanpp_available, is_keras_nlp_available, + is_kernels_available, is_levenshtein_available, is_librosa_available, is_liger_kernel_available, @@ -586,6 +587,16 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_kernels(test_case): + """ + Decorator marking a test that requires the `kernels` library. + + These tests are skipped when `kernels` isn't installed. + + """ + return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case) + + def require_flash_attn_3(test_case): """ Decorator marking a test that requires Flash Attention 3. @@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_mps(test_case): + """Decorator marking a test that requires an MPS device.""" + return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case) + + def require_large_cpu_ram(test_case, memory: float = 80): """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory.""" if not is_psutil_available(): diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 11eb382bda..f277df1af1 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict: for placeholder in placeholders: # Infer placeholders from the model name and the auto modules if placeholder in PLACEHOLDER_TO_AUTO_MODULE: - place_holder_value = getattr( - getattr(auto_module, 
PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), - PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], - ).get(model_name, None) + try: + place_holder_value = getattr( + getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]), + PLACEHOLDER_TO_AUTO_MODULE[placeholder][1], + ).get(model_name, None) + except ImportError: + # In case a library is not installed, we don't want to fail the docstring generation + place_holder_value = None if place_holder_value is not None: if isinstance(place_holder_value, (list, tuple)): place_holder_value = place_holder_value[0] @@ -1170,8 +1174,11 @@ def format_args_docstring(docstring, model_name): placeholders_dict = get_placeholders_dict(placeholders, model_name) # replace the placeholders in the docstring with the values from the placeholders_dict for placeholder, value in placeholders_dict.items(): - docstring = docstring.replace(f"{{{placeholder}}}", value) - + if placeholder is not None: + try: + docstring = docstring.replace(f"{{{placeholder}}}", value) + except Exception: + pass return docstring diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5589c8cc0d..9c4c0da4ee 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -86,12 +86,14 @@ from transformers.testing_utils import ( require_deepspeed, require_flash_attn, require_flash_attn_3, + require_kernels, require_non_hpu, require_safetensors, require_torch, require_torch_accelerator, require_torch_gpu, require_torch_greater_or_equal, + require_torch_mps, require_torch_multi_accelerator, require_torch_multi_gpu, require_torch_sdpa, @@ -3474,94 +3476,107 @@ class ModelTesterMixin: self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.head_dim = 64 # fa2 does not always support arbitrary headim model = model_class(config) - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - 
model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation - ) - model_fa.to(torch_device) + model.to(torch_device) + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) - model.to(torch_device) - - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) - - dummy_attention_mask = inputs_dict.get("attention_mask", None) - - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - if padding_side == "left": - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - else: - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - if model.config.is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - else: - outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) - - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) - - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - - if model.config.is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - 
outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) - - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] if padding_side == "left": - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) - - # check with inference + dropout - model.train() - _ = model_fa(dummy_input, **other_inputs) + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 else: - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, output_hidden_states=True) + + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else 
outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + model.set_attn_implementation(attn_implementation) + outputs_fa = model(dummy_input, **other_inputs) + + model.set_attn_implementation("sdpa") + logits = ( + outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + model.set_attn_implementation(attn_implementation) + _ = model(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_kernels + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernels_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left") + + @require_torch_mps + @require_kernels + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_kernels_mps_inference_equivalence(self): + self.flash_attn_inference_equivalence( + attn_implementation="kernels-community/metal-flash-sdpa", 
padding_side="left" + ) @require_flash_attn @require_torch_gpu