Kernels flash attn (#39474)
* use partial to wrap around `transformers` utils! * try to refactor? * revert one wrong change * just a nit * push * reverter watever was wrong! * some nits * fixes when there is no attention mask * bring the licence back * some fixes * nit * style * remove prints * correct dtype * fa flags for testing * update * use paged attention if requested! * updates * a clone was needed, not sure why * automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute * simplify and improve? * flash attention is kinda broken on recent cuda version so allow the opportunity to use something else * fix! * protect kernels import * update * properly parse generation config being passed * revert and update * add two tests * some fixes * fix test FA2 * takes comment into account * fixup * revert changes * revert the clone, it is only needed because the metal kernel is not doing it? * [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix mps on our side for now * Update src/transformers/integrations/flash_paged.py * no qa --------- Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -1119,7 +1119,8 @@ class ContinuousBatchingManager:
|
|||||||
self._request_lock = threading.Lock()
|
self._request_lock = threading.Lock()
|
||||||
self.model.generation_config.top_p = None
|
self.model.generation_config.top_p = None
|
||||||
self.do_sample = getattr(generation_config, "do_sample", True)
|
self.do_sample = getattr(generation_config, "do_sample", True)
|
||||||
self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
|
generation_config = model.generation_config if generation_config is None else generation_config
|
||||||
|
self.logit_processor = self.model._get_logits_processor(generation_config)
|
||||||
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
|
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
|
||||||
self.profile = getattr(generation_config, "profile", False)
|
self.profile = getattr(generation_config, "profile", False)
|
||||||
self.manual_eviction = manual_eviction
|
self.manual_eviction = manual_eviction
|
||||||
|
|||||||
@@ -677,6 +677,24 @@ class GenerationMixin(ContinuousMixin):
|
|||||||
if encoder_attention_mask is not None:
|
if encoder_attention_mask is not None:
|
||||||
model_inputs["attention_mask"] = encoder_attention_mask
|
model_inputs["attention_mask"] = encoder_attention_mask
|
||||||
|
|
||||||
|
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
|
||||||
|
tensor_kws = {"dtype": torch.int32, "device": self.device}
|
||||||
|
pos = model_inputs["position_ids"][:, -1]
|
||||||
|
|
||||||
|
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
|
||||||
|
max_length_k = int(pos.max()) + 1
|
||||||
|
|
||||||
|
bs, seq_len = input_ids.size()
|
||||||
|
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
|
||||||
|
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
|
||||||
|
max_length_q = int(q_len.max())
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
|
||||||
|
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
|
||||||
|
max_length_q=max_length_q,
|
||||||
|
max_length_k=max_length_k,
|
||||||
|
)
|
||||||
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key not in model_inputs:
|
if key not in model_inputs:
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ def flash_attention_forward(
|
|||||||
"FlashAttention does not support inputs with dim=0.\n"
|
"FlashAttention does not support inputs with dim=0.\n"
|
||||||
"Please check your input shapes or use SDPA instead."
|
"Please check your input shapes or use SDPA instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
# FA2 uses non-transposed inputs
|
# FA2 uses non-transposed inputs
|
||||||
query = query.transpose(1, 2)
|
query = query.transpose(1, 2)
|
||||||
key = key.transpose(1, 2)
|
key = key.transpose(1, 2)
|
||||||
@@ -76,6 +75,7 @@ def flash_attention_forward(
|
|||||||
use_top_left_mask=_use_top_left_mask,
|
use_top_left_mask=_use_top_left_mask,
|
||||||
target_dtype=target_dtype,
|
target_dtype=target_dtype,
|
||||||
attn_implementation=module.config._attn_implementation,
|
attn_implementation=module.config._attn_implementation,
|
||||||
|
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from ..utils import is_flash_attn_2_available
|
|||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def paged_attention_forward(
|
def paged_attention_forward(
|
||||||
@@ -20,6 +20,7 @@ def paged_attention_forward(
|
|||||||
max_seqlen_q=None,
|
max_seqlen_q=None,
|
||||||
max_seqlen_k=None,
|
max_seqlen_k=None,
|
||||||
block_tables=None,
|
block_tables=None,
|
||||||
|
implementation=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""Perform the forward pass of attention with paged key-value cache.
|
r"""Perform the forward pass of attention with paged key-value cache.
|
||||||
@@ -46,12 +47,14 @@ def paged_attention_forward(
|
|||||||
"""
|
"""
|
||||||
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
|
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
|
||||||
|
|
||||||
|
if implementation is not None:
|
||||||
|
flash_attn_varlen_func = implementation.flash_attn_varlen_func
|
||||||
attn_output = flash_attn_varlen_func(
|
attn_output = flash_attn_varlen_func(
|
||||||
q.transpose(1, 2).squeeze(0),
|
q.transpose(1, 2).squeeze(0).contiguous(),
|
||||||
k.transpose(1, 2).squeeze(0),
|
k.transpose(1, 2).squeeze(0).contiguous(),
|
||||||
v.transpose(1, 2).squeeze(0),
|
v.transpose(1, 2).squeeze(0).contiguous(),
|
||||||
cumulative_seqlens_q.to(torch.int32),
|
cumulative_seqlens_q.to(torch.int32),
|
||||||
cumulative_seqlens_k.to(torch.int32),
|
cumulative_seqlens_k.to(torch.int32).clone(),
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
max_seqlen_k,
|
max_seqlen_k,
|
||||||
softmax_scale=module.scaling,
|
softmax_scale=module.scaling,
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -20,6 +19,8 @@ from typing import Optional, TypedDict
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transformers.utils.import_utils import is_kernels_available
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_3_available,
|
is_flash_attn_3_available,
|
||||||
@@ -31,25 +32,16 @@ from .utils import (
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
flash_attn_func = None
|
|
||||||
|
|
||||||
|
|
||||||
def _index_first_axis(tensor, indices):
|
def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
|
||||||
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
|
return reshaped[indices]
|
||||||
after flattening the first two dimensions of the tensor. This is functionally equivalent to
|
|
||||||
FA2's `index_first_axis` and replaces the need to import it.
|
|
||||||
"""
|
|
||||||
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
|
|
||||||
# two dimensions to get (total_tokens, ...) before indexing.
|
|
||||||
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
|
|
||||||
return reshaped_tensor[indices]
|
|
||||||
|
|
||||||
|
|
||||||
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||||
"""
|
"""
|
||||||
FA3-compatible unpad_input function.
|
FA3-compatible unpad_input function.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
hidden_states: (batch, seqlen, ...)
|
hidden_states: (batch, seqlen, ...)
|
||||||
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
||||||
@@ -80,7 +72,6 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
|||||||
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
||||||
"""
|
"""
|
||||||
FA3-compatible pad_input function.
|
FA3-compatible pad_input function.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
||||||
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
||||||
@@ -95,109 +86,12 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
|||||||
return output.view(batch, seqlen, *dim)
|
return output.view(batch, seqlen, *dim)
|
||||||
|
|
||||||
|
|
||||||
FA_VERSION = None
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
from flash_attn import flash_attn_func as flash_attn_2_func
|
|
||||||
from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
|
|
||||||
from flash_attn.bert_padding import pad_input as pad_input_fa2
|
|
||||||
from flash_attn.bert_padding import unpad_input as unpad_input_fa2
|
|
||||||
from flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
|
|
||||||
HAS_FA2 = True
|
|
||||||
FA_VERSION = 2
|
|
||||||
elif is_torch_npu_available():
|
|
||||||
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
|
|
||||||
from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401
|
|
||||||
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
|
|
||||||
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
|
|
||||||
from .integrations.npu_flash_attention import pad_input as pad_input_fa2
|
|
||||||
from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
|
|
||||||
|
|
||||||
HAS_FA2 = True
|
|
||||||
FA_VERSION = 2
|
|
||||||
else:
|
|
||||||
flash_attn_2_func = None
|
|
||||||
flash_attn_2_varlen_func = None
|
|
||||||
pad_input_fa2 = None
|
|
||||||
unpad_input_fa2 = None
|
|
||||||
apply_rotary_emb = None
|
|
||||||
HAS_FA2 = False
|
|
||||||
|
|
||||||
if is_flash_attn_3_available():
|
|
||||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
|
||||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
|
||||||
|
|
||||||
pad_input_fa3 = _fa3_pad_input
|
|
||||||
unpad_input_fa3 = _fa3_unpad_input
|
|
||||||
HAS_FA3 = True
|
|
||||||
FA_VERSION = 3
|
|
||||||
else:
|
|
||||||
flash_attn_3_func = None
|
|
||||||
flash_attn_3_varlen_func = None
|
|
||||||
pad_input_fa3 = None
|
|
||||||
unpad_input_fa3 = None
|
|
||||||
HAS_FA3 = False
|
|
||||||
|
|
||||||
|
|
||||||
# Current Flash Attention implementations
|
|
||||||
if FA_VERSION:
|
|
||||||
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
|
|
||||||
flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
|
|
||||||
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
|
|
||||||
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
|
|
||||||
|
|
||||||
|
|
||||||
_flash_supports_window_size = False
|
|
||||||
|
|
||||||
|
|
||||||
if flash_attn_func:
|
|
||||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
|
||||||
|
|
||||||
|
|
||||||
def is_flash_attn_available():
|
|
||||||
"""Determine whether flash-attention can be used or not."""
|
|
||||||
|
|
||||||
if is_flash_attn_3_available():
|
|
||||||
return True
|
|
||||||
|
|
||||||
# if package `flash-attn` is available, flash-attention can be used natively.
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
return True
|
|
||||||
|
|
||||||
# flash-attention can be used on Ascend NPU without package `flash-attn`
|
|
||||||
if is_torch_npu_available():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attn_supports_top_left_mask():
|
|
||||||
"""Determine whether flash-attention uses top-left or down-right mask"""
|
|
||||||
|
|
||||||
if is_flash_attn_3_available():
|
|
||||||
return False
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
|
|
||||||
return not is_flash_attn_greater_or_equal_2_10()
|
|
||||||
|
|
||||||
if is_torch_npu_available():
|
|
||||||
# down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask.
|
|
||||||
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
|
||||||
|
|
||||||
return is_npu_fa2_top_left_aligned_causal_mask()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
|
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
"""
|
"""
|
||||||
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
indices (`torch.Tensor`):
|
indices (`torch.Tensor`):
|
||||||
The indices of non-masked tokens from the flattened input sequence.
|
The indices of non-masked tokens from the flattened input sequence.
|
||||||
@@ -229,10 +123,8 @@ def _upad_input(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
|
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
|
||||||
|
|
||||||
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
|
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
|
||||||
tensors for query, key, value tensors.
|
tensors for query, key, value tensors.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
query_layer (`torch.Tensor`):
|
query_layer (`torch.Tensor`):
|
||||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||||
@@ -246,7 +138,6 @@ def _upad_input(
|
|||||||
Target length.
|
Target length.
|
||||||
unpad_input_func:
|
unpad_input_func:
|
||||||
The function to use for unpadding the input tensors.
|
The function to use for unpadding the input tensors.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
query_layer (`torch.Tensor`):
|
query_layer (`torch.Tensor`):
|
||||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||||
@@ -299,14 +190,12 @@ def _upad_input(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
def _prepare_from_posids(query, key, value, position_ids):
|
||||||
"""
|
"""
|
||||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||||
All three query, key, value states will be flattened.
|
All three query, key, value states will be flattened.
|
||||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||||
|
|
||||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
query (`torch.Tensor`):
|
query (`torch.Tensor`):
|
||||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||||
@@ -316,7 +205,6 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
|||||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||||
position_ids (`torch.Tensor`):
|
position_ids (`torch.Tensor`):
|
||||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
query (`torch.Tensor`):
|
query (`torch.Tensor`):
|
||||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||||
@@ -331,19 +219,22 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
|||||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||||
"""
|
"""
|
||||||
query = query.view(-1, query.size(-2), query.size(-1))
|
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||||
|
cu_seqlens_k = torch.cat(
|
||||||
|
[torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
|
||||||
|
)
|
||||||
|
max_k = torch.max(position_ids, dim=1).values.max().item() + 1
|
||||||
position_ids = position_ids.flatten()
|
position_ids = position_ids.flatten()
|
||||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
cu_seq_lens = torch.cat(
|
cu_seq_lens = torch.cat(
|
||||||
(
|
(
|
||||||
indices_q[position_ids == 0],
|
torch.tensor([0], device=position_ids.device, dtype=torch.int32),
|
||||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||||
@@ -353,61 +244,101 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
|||||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||||
# for some models (e.g. qwen2-vl).
|
# for some models (e.g. qwen2-vl).
|
||||||
max_length = cu_seq_lens.diff().max().item()
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
|
||||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_fa2_from_position_ids(*args, **kwargs):
|
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
|
"prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
return _prepare_flash_attention_from_position_ids(*args, **kwargs)
|
return _prepare_from_posids(query, key, value, position_ids)
|
||||||
|
|
||||||
|
|
||||||
def fa_peft_integration_check(
|
def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
|
||||||
query: torch.Tensor,
|
if target_dtype and q.dtype == torch.float32:
|
||||||
key: torch.Tensor,
|
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
|
||||||
value: torch.Tensor,
|
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
|
||||||
target_dtype: Optional[torch.dtype] = None,
|
return q, k, v
|
||||||
):
|
|
||||||
"""
|
|
||||||
PEFT usually casts the layer norms in float32 for training stability reasons
|
|
||||||
therefore the input hidden states gets silently casted in float32. Hence, we need
|
|
||||||
cast them back in float16 / bfloat16 just to be sure everything works as expected.
|
|
||||||
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (`torch.Tensor`):
|
|
||||||
Input query states to be passed to Flash Attention API
|
|
||||||
key (`torch.Tensor`):
|
|
||||||
Input key states to be passed to Flash Attention API
|
|
||||||
value (`torch.Tensor`):
|
|
||||||
Input value states to be passed to Flash Attention API
|
|
||||||
target_dtype (`torch.dtype`, *optional*):
|
|
||||||
The dtype to convert the attention tensors to. Conversion can be ignored by
|
|
||||||
not providing the target dtype.
|
|
||||||
"""
|
|
||||||
if target_dtype is None:
|
|
||||||
return query, key, value
|
|
||||||
|
|
||||||
input_dtype = query.dtype
|
def _lazy_imports(impl: Optional[str]):
|
||||||
if input_dtype == torch.float32:
|
# returns funcs and pad/unpad based on impl
|
||||||
logger.warning_once(
|
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
|
||||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
is_fa3 = is_flash_attn_3_available()
|
||||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
|
||||||
f" {target_dtype}."
|
try:
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
|
||||||
|
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
if not globals().get("use_remote_fa2", None):
|
||||||
|
use_remote_fa2 = (
|
||||||
|
input(
|
||||||
|
"Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
|
||||||
|
)
|
||||||
|
.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
|
globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
|
||||||
|
if globals()["use_remote_fa2"]:
|
||||||
|
if not is_kernels_available():
|
||||||
|
raise ImportError("You need to install kernels: `pip install kernels`")
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
impl = get_kernel("kernels-community/flash-attn")
|
||||||
|
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||||
|
return (
|
||||||
|
getattr(impl, "flash_attn_func", None),
|
||||||
|
getattr(impl, "flash_attn_varlen_func"),
|
||||||
|
pad_input,
|
||||||
|
unpad_input,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
query = query.to(target_dtype)
|
else:
|
||||||
key = key.to(target_dtype)
|
raise ImportError(
|
||||||
value = value.to(target_dtype)
|
"Failed to import flash attention 2, please install it or use another implementation."
|
||||||
|
) from e
|
||||||
|
if impl == "flash_attention_3" or (impl is None and is_fa3):
|
||||||
|
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||||
|
|
||||||
return query, key, value
|
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||||
|
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
|
||||||
|
else:
|
||||||
|
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||||
|
return (
|
||||||
|
getattr(impl, "flash_attn_func", None),
|
||||||
|
getattr(impl, "flash_attn_varlen_func"),
|
||||||
|
pad_input,
|
||||||
|
unpad_input,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
|
_flash_supports_window = None
|
||||||
deterministic_g = None
|
|
||||||
|
|
||||||
|
def is_flash_attn_available():
|
||||||
|
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_supports_top_left_mask():
|
||||||
|
if is_flash_attn_3_available():
|
||||||
|
return False
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
return not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
||||||
|
|
||||||
|
return is_npu_fa2_top_left_aligned_causal_mask()
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
cumulative_seqlens_q: Optional[torch.LongTensor]
|
||||||
|
cumulative_seqlens_k: Optional[torch.LongTensor]
|
||||||
|
|
||||||
|
|
||||||
def _flash_attention_forward(
|
def _flash_attention_forward(
|
||||||
@@ -429,185 +360,100 @@ def _flash_attention_forward(
|
|||||||
max_length_q: Optional[int] = None,
|
max_length_q: Optional[int] = None,
|
||||||
max_length_k: Optional[int] = None,
|
max_length_k: Optional[int] = None,
|
||||||
target_dtype: Optional[torch.dtype] = None,
|
target_dtype: Optional[torch.dtype] = None,
|
||||||
attn_implementation: Optional[str] = None,
|
implementation: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
|
||||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
|
||||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
globals()["_flash_fn"] = flash_fn
|
||||||
|
globals()["_flash_varlen_fn"] = flash_varlen_fn
|
||||||
Args:
|
globals()["_pad_fn"] = pad_fn
|
||||||
query_states (`torch.Tensor`):
|
globals()["_unpad_fn"] = unpad_fn
|
||||||
Input query states to be passed to Flash Attention API
|
globals()["_is_fa3"] = is_fa3
|
||||||
key_states (`torch.Tensor`):
|
flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
|
||||||
Input key states to be passed to Flash Attention API
|
globals()["_flash_supports_window"] = flash_supports_window
|
||||||
value_states (`torch.Tensor`):
|
|
||||||
Input value states to be passed to Flash Attention API
|
|
||||||
attention_mask (`torch.Tensor`, *optional*):
|
|
||||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
|
||||||
position of padding tokens and 1 for the position of non-padding tokens.
|
|
||||||
dropout (`float`):
|
|
||||||
Attention dropout
|
|
||||||
softmax_scale (`float`, *optional*):
|
|
||||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
|
||||||
use_top_left_mask (`bool`, defaults to `False`):
|
|
||||||
flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
|
|
||||||
softcap (`float`, *optional*):
|
|
||||||
Softcap for the attention logits, used e.g. in gemma2.
|
|
||||||
deterministic (`bool`, *optional*):
|
|
||||||
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
|
|
||||||
attn_implementation (`str`, *optional*):
|
|
||||||
The attention implementation to use. If None, will default to the one based on the environment.
|
|
||||||
"""
|
|
||||||
if attn_implementation is None:
|
|
||||||
_flash_attn_varlen_func = flash_attn_varlen_func
|
|
||||||
_flash_attn_func = flash_attn_func
|
|
||||||
_pad_input = pad_input
|
|
||||||
_unpad_input = unpad_input
|
|
||||||
_is_fa3 = HAS_FA3
|
|
||||||
elif attn_implementation == "flash_attention_3":
|
|
||||||
_flash_attn_varlen_func = flash_attn_3_varlen_func
|
|
||||||
_flash_attn_func = flash_attn_3_func
|
|
||||||
_pad_input = pad_input_fa3
|
|
||||||
_unpad_input = unpad_input_fa3
|
|
||||||
_is_fa3 = True
|
|
||||||
elif attn_implementation == "flash_attention_2":
|
|
||||||
_flash_attn_varlen_func = flash_attn_2_varlen_func
|
|
||||||
_flash_attn_func = flash_attn_2_func
|
|
||||||
_pad_input = pad_input_fa2
|
|
||||||
_unpad_input = unpad_input_fa2
|
|
||||||
_is_fa3 = False
|
|
||||||
|
|
||||||
if not use_top_left_mask:
|
|
||||||
causal = is_causal
|
|
||||||
else:
|
else:
|
||||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
|
flash_fn = globals()["_flash_fn"]
|
||||||
causal = is_causal and query_length != 1
|
flash_varlen_fn = globals()["_flash_varlen_fn"]
|
||||||
|
pad_fn = globals()["_pad_fn"]
|
||||||
|
unpad_fn = globals()["_unpad_fn"]
|
||||||
|
is_fa3 = globals()["_is_fa3"]
|
||||||
|
flash_supports_window = globals()["_flash_supports_window"]
|
||||||
|
|
||||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
causal = is_causal and not (use_top_left_mask and query_length == 1)
|
||||||
use_sliding_windows = (
|
use_sw = (
|
||||||
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
(_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
|
||||||
)
|
)
|
||||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
|
||||||
|
if not is_fa3:
|
||||||
if _is_fa3:
|
|
||||||
if dropout > 0.0:
|
|
||||||
logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
|
|
||||||
else:
|
|
||||||
flash_kwargs["dropout_p"] = dropout
|
flash_kwargs["dropout_p"] = dropout
|
||||||
|
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||||
if flash_241:
|
det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||||
if deterministic is None:
|
flash_kwargs["deterministic"] = det
|
||||||
global deterministic_g
|
|
||||||
if deterministic_g is None:
|
|
||||||
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
|
||||||
deterministic = deterministic_g
|
|
||||||
flash_kwargs["deterministic"] = deterministic
|
|
||||||
|
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
flash_kwargs["softcap"] = softcap
|
flash_kwargs["softcap"] = softcap
|
||||||
|
|
||||||
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
|
|
||||||
query_states, key_states, value_states = fa_peft_integration_check(
|
query_states, key_states, value_states = fa_peft_integration_check(
|
||||||
query_states, key_states, value_states, target_dtype
|
query_states, key_states, value_states, target_dtype
|
||||||
)
|
)
|
||||||
|
use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
|
||||||
# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
|
||||||
# under two cases:
|
|
||||||
# Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
|
||||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
|
||||||
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
|
|
||||||
# use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility
|
|
||||||
# to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information
|
|
||||||
is_fa2_with_position_ids = (
|
|
||||||
position_ids is not None
|
|
||||||
and query_states.shape[0] == 1
|
|
||||||
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
|
|
||||||
)
|
|
||||||
is_fa2_with_varlen_kwargs = all(
|
|
||||||
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Contains at least one padding token in the sequence
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
batch_size = query_states.shape[0]
|
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
|
||||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
|
||||||
query_states, key_states, value_states, attention_mask, query_length, _unpad_input
|
|
||||||
)
|
)
|
||||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
# TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
|
||||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
if "mps" in str(q.device):
|
||||||
|
cu_k = cu_k.clone()
|
||||||
attn_output_unpad = _flash_attn_varlen_func(
|
out_unpad = flash_varlen_fn(
|
||||||
query_states,
|
q,
|
||||||
key_states,
|
k,
|
||||||
value_states,
|
v,
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
cu_seqlens_q=cu_q.to(torch.int32),
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
cu_seqlens_k=cu_k.to(torch.int32),
|
||||||
max_seqlen_q=max_seqlen_in_batch_q,
|
max_seqlen_q=mq,
|
||||||
max_seqlen_k=max_seqlen_in_batch_k,
|
max_seqlen_k=mk,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
**flash_kwargs,
|
**flash_kwargs,
|
||||||
)
|
)
|
||||||
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
if isinstance(out_unpad, tuple):
|
||||||
|
out_unpad = out_unpad[0]
|
||||||
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
|
out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
|
||||||
batch_size = query_states.size(0)
|
elif use_mask:
|
||||||
|
|
||||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
|
if position_ids is None:
|
||||||
_prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
|
raise ValueError(
|
||||||
|
"Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
|
||||||
|
)
|
||||||
|
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
|
||||||
|
query_states, key_states, value_states, position_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
|
||||||
max_length_q, max_length_k = max_seq_lens
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||||
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||||
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||||
|
mq, mk = max_length_q, max_length_k
|
||||||
attn_output = _flash_attn_varlen_func(
|
cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
|
||||||
query_states,
|
if "mps" in str(q.device):
|
||||||
key_states,
|
cu_k = cu_k.clone()
|
||||||
value_states,
|
out = flash_varlen_fn(
|
||||||
cu_seqlens_q=cu_seq_lens_q,
|
q,
|
||||||
cu_seqlens_k=cu_seq_lens_k,
|
k,
|
||||||
max_seqlen_q=max_length_q,
|
v,
|
||||||
max_seqlen_k=max_length_k,
|
cu_seqlens_q=cu_q.to(torch.int32),
|
||||||
|
cu_seqlens_k=cu_k.to(torch.int32),
|
||||||
|
max_seqlen_q=mq,
|
||||||
|
max_seqlen_k=mk,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
**flash_kwargs,
|
**flash_kwargs,
|
||||||
)
|
)
|
||||||
|
if isinstance(out, tuple):
|
||||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
out = out[0]
|
||||||
|
out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
|
||||||
else:
|
else:
|
||||||
attn_output = _flash_attn_func(
|
out = flash_fn(
|
||||||
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(attn_output, tuple):
|
return out[0] if isinstance(out, tuple) else out
|
||||||
return attn_output[0]
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionKwargs(TypedDict, total=False):
|
|
||||||
"""
|
|
||||||
Keyword arguments for Flash Attention with Compile.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
cumulative_seqlens_q (`torch.LongTensor`, *optional*)
|
|
||||||
Gets cumulative sequence length for query state.
|
|
||||||
cumulative_seqlens_k (`torch.LongTensor`, *optional*)
|
|
||||||
Gets cumulative sequence length for key state.
|
|
||||||
max_length_q (`int`, *optional*):
|
|
||||||
Maximum sequence length for query state.
|
|
||||||
max_length_k (`int`, *optional*):
|
|
||||||
Maximum sequence length for key state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cumulative_seqlens_q: Optional[torch.LongTensor]
|
|
||||||
cumulative_seqlens_k: Optional[torch.LongTensor]
|
|
||||||
max_length_q: Optional[int]
|
|
||||||
max_length_k: Optional[int]
|
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from .integrations.tensor_parallel import (
|
|||||||
verify_tp_plan,
|
verify_tp_plan,
|
||||||
)
|
)
|
||||||
from .loss.loss_utils import LOSS_MAPPING
|
from .loss.loss_utils import LOSS_MAPPING
|
||||||
|
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
|
||||||
from .pytorch_utils import ( # noqa: F401
|
from .pytorch_utils import ( # noqa: F401
|
||||||
Conv1D,
|
Conv1D,
|
||||||
apply_chunking_to_forward,
|
apply_chunking_to_forward,
|
||||||
@@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
None to sdpa (to potentially eager).
|
None to sdpa (to potentially eager).
|
||||||
"""
|
"""
|
||||||
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
|
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
|
||||||
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
|
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
|
||||||
if not is_kernels_available():
|
if not is_kernels_available():
|
||||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||||
|
|
||||||
# Extract repo_id and kernel_name from the string
|
# Extract repo_id and kernel_name from the string
|
||||||
repo_id, kernel_name = applicable_attn_implementation.split(":")
|
if ":" in applicable_attn_implementation:
|
||||||
|
repo_id, kernel_name = attn_implementation.split(":")
|
||||||
kernel_name = kernel_name.strip()
|
kernel_name = kernel_name.strip()
|
||||||
|
else:
|
||||||
|
repo_id = attn_implementation
|
||||||
|
kernel_name = None
|
||||||
repo_id = repo_id.strip()
|
repo_id = repo_id.strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kernel = get_kernel(repo_id)
|
kernel = get_kernel(repo_id)
|
||||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||||
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
|
||||||
|
flash_attention_forward, implementation=kernel
|
||||||
|
)
|
||||||
|
elif kernel_name is not None:
|
||||||
|
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
|
||||||
|
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
|
||||||
|
"flash_attention_2"
|
||||||
|
]
|
||||||
|
applicable_attn_implementation = repo_id
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||||
)
|
)
|
||||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||||
except AttributeError:
|
finally:
|
||||||
raise ValueError(
|
return applicable_attn_implementation
|
||||||
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
|
|
||||||
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
|
|
||||||
)
|
|
||||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||||
message = (
|
message = (
|
||||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ from .utils import (
|
|||||||
is_jinja_available,
|
is_jinja_available,
|
||||||
is_jumanpp_available,
|
is_jumanpp_available,
|
||||||
is_keras_nlp_available,
|
is_keras_nlp_available,
|
||||||
|
is_kernels_available,
|
||||||
is_levenshtein_available,
|
is_levenshtein_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_liger_kernel_available,
|
is_liger_kernel_available,
|
||||||
@@ -586,6 +587,16 @@ def require_flash_attn(test_case):
|
|||||||
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_kernels(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires Flash Attention.
|
||||||
|
|
||||||
|
These tests are skipped when Flash Attention isn't installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_flash_attn_3(test_case):
|
def require_flash_attn_3(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires Flash Attention 3.
|
Decorator marking a test that requires Flash Attention 3.
|
||||||
@@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case):
|
|||||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_mps(test_case):
|
||||||
|
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||||
|
return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_large_cpu_ram(test_case, memory: float = 80):
|
def require_large_cpu_ram(test_case, memory: float = 80):
|
||||||
"""Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
|
"""Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
|
||||||
if not is_psutil_available():
|
if not is_psutil_available():
|
||||||
|
|||||||
@@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict:
|
|||||||
for placeholder in placeholders:
|
for placeholder in placeholders:
|
||||||
# Infer placeholders from the model name and the auto modules
|
# Infer placeholders from the model name and the auto modules
|
||||||
if placeholder in PLACEHOLDER_TO_AUTO_MODULE:
|
if placeholder in PLACEHOLDER_TO_AUTO_MODULE:
|
||||||
|
try:
|
||||||
place_holder_value = getattr(
|
place_holder_value = getattr(
|
||||||
getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
|
getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
|
||||||
PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
|
PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
|
||||||
).get(model_name, None)
|
).get(model_name, None)
|
||||||
|
except ImportError:
|
||||||
|
# In case a library is not installed, we don't want to fail the docstring generation
|
||||||
|
place_holder_value = None
|
||||||
if place_holder_value is not None:
|
if place_holder_value is not None:
|
||||||
if isinstance(place_holder_value, (list, tuple)):
|
if isinstance(place_holder_value, (list, tuple)):
|
||||||
place_holder_value = place_holder_value[0]
|
place_holder_value = place_holder_value[0]
|
||||||
@@ -1170,8 +1174,11 @@ def format_args_docstring(docstring, model_name):
|
|||||||
placeholders_dict = get_placeholders_dict(placeholders, model_name)
|
placeholders_dict = get_placeholders_dict(placeholders, model_name)
|
||||||
# replace the placeholders in the docstring with the values from the placeholders_dict
|
# replace the placeholders in the docstring with the values from the placeholders_dict
|
||||||
for placeholder, value in placeholders_dict.items():
|
for placeholder, value in placeholders_dict.items():
|
||||||
|
if placeholder is not None:
|
||||||
|
try:
|
||||||
docstring = docstring.replace(f"{{{placeholder}}}", value)
|
docstring = docstring.replace(f"{{{placeholder}}}", value)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return docstring
|
return docstring
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -86,12 +86,14 @@ from transformers.testing_utils import (
|
|||||||
require_deepspeed,
|
require_deepspeed,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_flash_attn_3,
|
require_flash_attn_3,
|
||||||
|
require_kernels,
|
||||||
require_non_hpu,
|
require_non_hpu,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_greater_or_equal,
|
require_torch_greater_or_equal,
|
||||||
|
require_torch_mps,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
@@ -3474,18 +3476,10 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.head_dim = 64 # fa2 does not always support arbitrary headim
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
model_fa = model_class.from_pretrained(
|
|
||||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
|
|
||||||
)
|
|
||||||
model_fa.to(torch_device)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||||
dummy_input = dummy_input.to(torch.bfloat16)
|
dummy_input = dummy_input.to(torch.bfloat16)
|
||||||
@@ -3504,15 +3498,16 @@ class ModelTesterMixin:
|
|||||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||||
|
|
||||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||||
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
model.set_attn_implementation(attn_implementation)
|
||||||
|
outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||||
else:
|
else:
|
||||||
outputs = model(dummy_input, output_hidden_states=True)
|
outputs = model(dummy_input, output_hidden_states=True)
|
||||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
model.set_attn_implementation(attn_implementation)
|
||||||
|
outputs_fa = model(dummy_input, output_hidden_states=True)
|
||||||
|
|
||||||
|
model.set_attn_implementation("sdpa")
|
||||||
logits = (
|
logits = (
|
||||||
outputs.hidden_states[-1]
|
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||||
if not model.config.is_encoder_decoder
|
|
||||||
else outputs.decoder_hidden_states[-1]
|
|
||||||
)
|
)
|
||||||
logits_fa = (
|
logits_fa = (
|
||||||
outputs_fa.hidden_states[-1]
|
outputs_fa.hidden_states[-1]
|
||||||
@@ -3532,7 +3527,8 @@ class ModelTesterMixin:
|
|||||||
other_inputs["attention_mask"] = dummy_attention_mask
|
other_inputs["attention_mask"] = dummy_attention_mask
|
||||||
|
|
||||||
outputs = model(dummy_input, **other_inputs)
|
outputs = model(dummy_input, **other_inputs)
|
||||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
model.set_attn_implementation(attn_implementation)
|
||||||
|
outputs_fa = model(dummy_input, **other_inputs)
|
||||||
else:
|
else:
|
||||||
other_inputs = {
|
other_inputs = {
|
||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
@@ -3541,12 +3537,12 @@ class ModelTesterMixin:
|
|||||||
other_inputs["attention_mask"] = dummy_attention_mask
|
other_inputs["attention_mask"] = dummy_attention_mask
|
||||||
|
|
||||||
outputs = model(dummy_input, **other_inputs)
|
outputs = model(dummy_input, **other_inputs)
|
||||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
model.set_attn_implementation(attn_implementation)
|
||||||
|
outputs_fa = model(dummy_input, **other_inputs)
|
||||||
|
|
||||||
|
model.set_attn_implementation("sdpa")
|
||||||
logits = (
|
logits = (
|
||||||
outputs.hidden_states[-1]
|
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||||
if not model.config.is_encoder_decoder
|
|
||||||
else outputs.decoder_hidden_states[-1]
|
|
||||||
)
|
)
|
||||||
logits_fa = (
|
logits_fa = (
|
||||||
outputs_fa.hidden_states[-1]
|
outputs_fa.hidden_states[-1]
|
||||||
@@ -3559,10 +3555,29 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# check with inference + dropout
|
# check with inference + dropout
|
||||||
model.train()
|
model.train()
|
||||||
_ = model_fa(dummy_input, **other_inputs)
|
model.set_attn_implementation(attn_implementation)
|
||||||
|
_ = model(dummy_input, **other_inputs)
|
||||||
else:
|
else:
|
||||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||||
|
|
||||||
|
@require_kernels
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
@is_flaky()
|
||||||
|
def test_flash_attn_kernels_inference_equivalence(self):
|
||||||
|
self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left")
|
||||||
|
|
||||||
|
@require_torch_mps
|
||||||
|
@require_kernels
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
@is_flaky()
|
||||||
|
def test_flash_attn_kernels_mps_inference_equivalence(self):
|
||||||
|
self.flash_attn_inference_equivalence(
|
||||||
|
attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left"
|
||||||
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
|||||||
Reference in New Issue
Block a user