Kernels flash attn (#39474)
* use partial to wrap around `transformers` utils! * try to refactor? * revert one wrong change * just a nit * push * reverter watever was wrong! * some nits * fixes when there is no attention mask * bring the licence back * some fixes * nit * style * remove prints * correct dtype * fa flags for testing * update * use paged attention if requested! * updates * a clone was needed, not sure why * automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute * simplify and improve? * flash attention is kinda broken on recent cuda version so allow the opportunity to use something else * fix! * protect kernels import * update * properly parse generation config being passed * revert and update * add two tests * some fixes * fix test FA2 * takes comment into account * fixup * revert changes * revert the clone, it is only needed because the metal kernel is not doing it? * [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix mps on our side for now * Update src/transformers/integrations/flash_paged.py * no qa --------- Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -1119,7 +1119,8 @@ class ContinuousBatchingManager:
|
||||
self._request_lock = threading.Lock()
|
||||
self.model.generation_config.top_p = None
|
||||
self.do_sample = getattr(generation_config, "do_sample", True)
|
||||
self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
|
||||
generation_config = model.generation_config if generation_config is None else generation_config
|
||||
self.logit_processor = self.model._get_logits_processor(generation_config)
|
||||
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
|
||||
self.profile = getattr(generation_config, "profile", False)
|
||||
self.manual_eviction = manual_eviction
|
||||
|
||||
@@ -677,6 +677,24 @@ class GenerationMixin(ContinuousMixin):
|
||||
if encoder_attention_mask is not None:
|
||||
model_inputs["attention_mask"] = encoder_attention_mask
|
||||
|
||||
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
|
||||
tensor_kws = {"dtype": torch.int32, "device": self.device}
|
||||
pos = model_inputs["position_ids"][:, -1]
|
||||
|
||||
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
|
||||
max_length_k = int(pos.max()) + 1
|
||||
|
||||
bs, seq_len = input_ids.size()
|
||||
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
|
||||
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
|
||||
max_length_q = int(q_len.max())
|
||||
|
||||
model_inputs.update(
|
||||
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
|
||||
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
|
||||
max_length_q=max_length_q,
|
||||
max_length_k=max_length_k,
|
||||
)
|
||||
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
||||
for key, value in kwargs.items():
|
||||
if key not in model_inputs:
|
||||
|
||||
@@ -38,7 +38,6 @@ def flash_attention_forward(
|
||||
"FlashAttention does not support inputs with dim=0.\n"
|
||||
"Please check your input shapes or use SDPA instead."
|
||||
)
|
||||
|
||||
# FA2 uses non-transposed inputs
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
@@ -76,6 +75,7 @@ def flash_attention_forward(
|
||||
use_top_left_mask=_use_top_left_mask,
|
||||
target_dtype=target_dtype,
|
||||
attn_implementation=module.config._attn_implementation,
|
||||
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..utils import is_flash_attn_2_available
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
|
||||
|
||||
def paged_attention_forward(
|
||||
@@ -20,6 +20,7 @@ def paged_attention_forward(
|
||||
max_seqlen_q=None,
|
||||
max_seqlen_k=None,
|
||||
block_tables=None,
|
||||
implementation=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""Perform the forward pass of attention with paged key-value cache.
|
||||
@@ -46,12 +47,14 @@ def paged_attention_forward(
|
||||
"""
|
||||
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
|
||||
|
||||
if implementation is not None:
|
||||
flash_attn_varlen_func = implementation.flash_attn_varlen_func
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q.transpose(1, 2).squeeze(0),
|
||||
k.transpose(1, 2).squeeze(0),
|
||||
v.transpose(1, 2).squeeze(0),
|
||||
q.transpose(1, 2).squeeze(0).contiguous(),
|
||||
k.transpose(1, 2).squeeze(0).contiguous(),
|
||||
v.transpose(1, 2).squeeze(0).contiguous(),
|
||||
cumulative_seqlens_q.to(torch.int32),
|
||||
cumulative_seqlens_k.to(torch.int32),
|
||||
cumulative_seqlens_k.to(torch.int32).clone(),
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
softmax_scale=module.scaling,
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
@@ -20,6 +19,8 @@ from typing import Optional, TypedDict
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.utils.import_utils import is_kernels_available
|
||||
|
||||
from .utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_3_available,
|
||||
@@ -31,25 +32,16 @@ from .utils import (
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
flash_attn_func = None
|
||||
|
||||
|
||||
def _index_first_axis(tensor, indices):
|
||||
"""
|
||||
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
|
||||
after flattening the first two dimensions of the tensor. This is functionally equivalent to
|
||||
FA2's `index_first_axis` and replaces the need to import it.
|
||||
"""
|
||||
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
|
||||
# two dimensions to get (total_tokens, ...) before indexing.
|
||||
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
|
||||
return reshaped_tensor[indices]
|
||||
def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
|
||||
return reshaped[indices]
|
||||
|
||||
|
||||
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
"""
|
||||
FA3-compatible unpad_input function.
|
||||
|
||||
Arguments:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
||||
@@ -80,7 +72,6 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
||||
"""
|
||||
FA3-compatible pad_input function.
|
||||
|
||||
Arguments:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
||||
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
||||
@@ -95,109 +86,12 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
||||
return output.view(batch, seqlen, *dim)
|
||||
|
||||
|
||||
FA_VERSION = None
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func as flash_attn_2_func
|
||||
from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
|
||||
from flash_attn.bert_padding import pad_input as pad_input_fa2
|
||||
from flash_attn.bert_padding import unpad_input as unpad_input_fa2
|
||||
from flash_attn.layers.rotary import apply_rotary_emb
|
||||
|
||||
HAS_FA2 = True
|
||||
FA_VERSION = 2
|
||||
elif is_torch_npu_available():
|
||||
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
|
||||
from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa: F401
|
||||
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
|
||||
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
|
||||
from .integrations.npu_flash_attention import pad_input as pad_input_fa2
|
||||
from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
|
||||
|
||||
HAS_FA2 = True
|
||||
FA_VERSION = 2
|
||||
else:
|
||||
flash_attn_2_func = None
|
||||
flash_attn_2_varlen_func = None
|
||||
pad_input_fa2 = None
|
||||
unpad_input_fa2 = None
|
||||
apply_rotary_emb = None
|
||||
HAS_FA2 = False
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
|
||||
pad_input_fa3 = _fa3_pad_input
|
||||
unpad_input_fa3 = _fa3_unpad_input
|
||||
HAS_FA3 = True
|
||||
FA_VERSION = 3
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
pad_input_fa3 = None
|
||||
unpad_input_fa3 = None
|
||||
HAS_FA3 = False
|
||||
|
||||
|
||||
# Current Flash Attention implementations
|
||||
if FA_VERSION:
|
||||
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
|
||||
flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
|
||||
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
|
||||
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
|
||||
|
||||
|
||||
_flash_supports_window_size = False
|
||||
|
||||
|
||||
if flash_attn_func:
|
||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
|
||||
def is_flash_attn_available():
|
||||
"""Determine whether flash-attention can be used or not."""
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
return True
|
||||
|
||||
# if package `flash-attn` is available, flash-attention can be used natively.
|
||||
if is_flash_attn_2_available():
|
||||
return True
|
||||
|
||||
# flash-attention can be used on Ascend NPU without package `flash-attn`
|
||||
if is_torch_npu_available():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def flash_attn_supports_top_left_mask():
|
||||
"""Determine whether flash-attention uses top-left or down-right mask"""
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
return False
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
|
||||
return not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
if is_torch_npu_available():
|
||||
# down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask.
|
||||
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
||||
|
||||
return is_npu_fa2_top_left_aligned_causal_mask()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""
|
||||
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
||||
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
|
||||
Return:
|
||||
indices (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input sequence.
|
||||
@@ -229,10 +123,8 @@ def _upad_input(
|
||||
):
|
||||
"""
|
||||
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
|
||||
|
||||
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
|
||||
tensors for query, key, value tensors.
|
||||
|
||||
Arguments:
|
||||
query_layer (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
@@ -246,7 +138,6 @@ def _upad_input(
|
||||
Target length.
|
||||
unpad_input_func:
|
||||
The function to use for unpadding the input tensors.
|
||||
|
||||
Return:
|
||||
query_layer (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
@@ -299,14 +190,12 @@ def _upad_input(
|
||||
)
|
||||
|
||||
|
||||
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
@@ -316,7 +205,6 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
@@ -331,19 +219,22 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.view(-1, query.size(-2), query.size(-1))
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
cu_seqlens_k = torch.cat(
|
||||
[torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0
|
||||
)
|
||||
max_k = torch.max(position_ids, dim=1).values.max().item() + 1
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor([0], device=position_ids.device, dtype=torch.int32),
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
@@ -353,61 +244,101 @@ def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seqlens_k), (max_length, max_k))
|
||||
|
||||
|
||||
def prepare_fa2_from_position_ids(*args, **kwargs):
|
||||
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
warnings.warn(
|
||||
"The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
|
||||
"prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
|
||||
FutureWarning,
|
||||
)
|
||||
return _prepare_flash_attention_from_position_ids(*args, **kwargs)
|
||||
return _prepare_from_posids(query, key, value, position_ids)
|
||||
|
||||
|
||||
def fa_peft_integration_check(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
target_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""
|
||||
PEFT usually casts the layer norms in float32 for training stability reasons
|
||||
therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
cast them back in float16 / bfloat16 just to be sure everything works as expected.
|
||||
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
|
||||
def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
|
||||
if target_dtype and q.dtype == torch.float32:
|
||||
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
|
||||
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
|
||||
return q, k, v
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
target_dtype (`torch.dtype`, *optional*):
|
||||
The dtype to convert the attention tensors to. Conversion can be ignored by
|
||||
not providing the target dtype.
|
||||
"""
|
||||
if target_dtype is None:
|
||||
return query, key, value
|
||||
|
||||
input_dtype = query.dtype
|
||||
if input_dtype == torch.float32:
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
def _lazy_imports(impl: Optional[str]):
|
||||
# returns funcs and pad/unpad based on impl
|
||||
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
|
||||
is_fa3 = is_flash_attn_3_available()
|
||||
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
|
||||
|
||||
except ImportError as e:
|
||||
if not globals().get("use_remote_fa2", None):
|
||||
use_remote_fa2 = (
|
||||
input(
|
||||
"Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
|
||||
)
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
|
||||
if globals()["use_remote_fa2"]:
|
||||
if not is_kernels_available():
|
||||
raise ImportError("You need to install kernels: `pip install kernels`")
|
||||
from kernels import get_kernel
|
||||
|
||||
impl = get_kernel("kernels-community/flash-attn")
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return (
|
||||
getattr(impl, "flash_attn_func", None),
|
||||
getattr(impl, "flash_attn_varlen_func"),
|
||||
pad_input,
|
||||
unpad_input,
|
||||
True,
|
||||
)
|
||||
|
||||
query = query.to(target_dtype)
|
||||
key = key.to(target_dtype)
|
||||
value = value.to(target_dtype)
|
||||
else:
|
||||
raise ImportError(
|
||||
"Failed to import flash attention 2, please install it or use another implementation."
|
||||
) from e
|
||||
if impl == "flash_attention_3" or (impl is None and is_fa3):
|
||||
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
return query, key, value
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
|
||||
else:
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return (
|
||||
getattr(impl, "flash_attn_func", None),
|
||||
getattr(impl, "flash_attn_varlen_func"),
|
||||
pad_input,
|
||||
unpad_input,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
|
||||
deterministic_g = None
|
||||
_flash_supports_window = None
|
||||
|
||||
|
||||
def is_flash_attn_available():
|
||||
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
|
||||
|
||||
|
||||
def flash_attn_supports_top_left_mask():
|
||||
if is_flash_attn_3_available():
|
||||
return False
|
||||
if is_flash_attn_2_available():
|
||||
return not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
||||
|
||||
return is_npu_fa2_top_left_aligned_causal_mask()
|
||||
|
||||
|
||||
class FlashAttentionKwargs(TypedDict, total=False):
|
||||
cumulative_seqlens_q: Optional[torch.LongTensor]
|
||||
cumulative_seqlens_k: Optional[torch.LongTensor]
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
@@ -429,185 +360,100 @@ def _flash_attention_forward(
|
||||
max_length_q: Optional[int] = None,
|
||||
max_length_k: Optional[int] = None,
|
||||
target_dtype: Optional[torch.dtype] = None,
|
||||
attn_implementation: Optional[str] = None,
|
||||
implementation: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`float`):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
use_top_left_mask (`bool`, defaults to `False`):
|
||||
flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
|
||||
softcap (`float`, *optional*):
|
||||
Softcap for the attention logits, used e.g. in gemma2.
|
||||
deterministic (`bool`, *optional*):
|
||||
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
|
||||
attn_implementation (`str`, *optional*):
|
||||
The attention implementation to use. If None, will default to the one based on the environment.
|
||||
"""
|
||||
if attn_implementation is None:
|
||||
_flash_attn_varlen_func = flash_attn_varlen_func
|
||||
_flash_attn_func = flash_attn_func
|
||||
_pad_input = pad_input
|
||||
_unpad_input = unpad_input
|
||||
_is_fa3 = HAS_FA3
|
||||
elif attn_implementation == "flash_attention_3":
|
||||
_flash_attn_varlen_func = flash_attn_3_varlen_func
|
||||
_flash_attn_func = flash_attn_3_func
|
||||
_pad_input = pad_input_fa3
|
||||
_unpad_input = unpad_input_fa3
|
||||
_is_fa3 = True
|
||||
elif attn_implementation == "flash_attention_2":
|
||||
_flash_attn_varlen_func = flash_attn_2_varlen_func
|
||||
_flash_attn_func = flash_attn_2_func
|
||||
_pad_input = pad_input_fa2
|
||||
_unpad_input = unpad_input_fa2
|
||||
_is_fa3 = False
|
||||
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
|
||||
flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
|
||||
globals()["_flash_fn"] = flash_fn
|
||||
globals()["_flash_varlen_fn"] = flash_varlen_fn
|
||||
globals()["_pad_fn"] = pad_fn
|
||||
globals()["_unpad_fn"] = unpad_fn
|
||||
globals()["_is_fa3"] = is_fa3
|
||||
flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
|
||||
globals()["_flash_supports_window"] = flash_supports_window
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
|
||||
causal = is_causal and query_length != 1
|
||||
flash_fn = globals()["_flash_fn"]
|
||||
flash_varlen_fn = globals()["_flash_varlen_fn"]
|
||||
pad_fn = globals()["_pad_fn"]
|
||||
unpad_fn = globals()["_unpad_fn"]
|
||||
is_fa3 = globals()["_is_fa3"]
|
||||
flash_supports_window = globals()["_flash_supports_window"]
|
||||
|
||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
||||
causal = is_causal and not (use_top_left_mask and query_length == 1)
|
||||
use_sw = (
|
||||
(_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
|
||||
)
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||
|
||||
if _is_fa3:
|
||||
if dropout > 0.0:
|
||||
logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
|
||||
else:
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
|
||||
if not is_fa3:
|
||||
flash_kwargs["dropout_p"] = dropout
|
||||
|
||||
if flash_241:
|
||||
if deterministic is None:
|
||||
global deterministic_g
|
||||
if deterministic_g is None:
|
||||
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
deterministic = deterministic_g
|
||||
flash_kwargs["deterministic"] = deterministic
|
||||
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs["deterministic"] = det
|
||||
if softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
|
||||
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
|
||||
query_states, key_states, value_states = fa_peft_integration_check(
|
||||
query_states, key_states, value_states, target_dtype
|
||||
)
|
||||
|
||||
# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
# under two cases:
|
||||
# Case 1. If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
|
||||
# use `flash_attn_varlen_func` knowing we already have all necessary the kwargs. NOTE: it is user's responsibility
|
||||
# to take care of flattenning `position_ids` if that's needed by the model. See #39121 for more information
|
||||
is_fa2_with_position_ids = (
|
||||
position_ids is not None
|
||||
and query_states.shape[0] == 1
|
||||
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
|
||||
)
|
||||
is_fa2_with_varlen_kwargs = all(
|
||||
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
|
||||
)
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length, _unpad_input
|
||||
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = _flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
# TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
|
||||
if "mps" in str(q.device):
|
||||
cu_k = cu_k.clone()
|
||||
out_unpad = flash_varlen_fn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_q.to(torch.int32),
|
||||
cu_seqlens_k=cu_k.to(torch.int32),
|
||||
max_seqlen_q=mq,
|
||||
max_seqlen_k=mk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
|
||||
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
|
||||
batch_size = query_states.size(0)
|
||||
|
||||
if isinstance(out_unpad, tuple):
|
||||
out_unpad = out_unpad[0]
|
||||
out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
|
||||
elif use_mask:
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
|
||||
_prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
|
||||
if position_ids is None:
|
||||
raise ValueError(
|
||||
"Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
|
||||
)
|
||||
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
)
|
||||
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
max_length_q, max_length_k = max_seq_lens
|
||||
|
||||
else:
|
||||
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||
|
||||
attn_output = _flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seq_lens_q,
|
||||
cu_seqlens_k=cu_seq_lens_k,
|
||||
max_seqlen_q=max_length_q,
|
||||
max_seqlen_k=max_length_k,
|
||||
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||
mq, mk = max_length_q, max_length_k
|
||||
cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
|
||||
if "mps" in str(q.device):
|
||||
cu_k = cu_k.clone()
|
||||
out = flash_varlen_fn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_q.to(torch.int32),
|
||||
cu_seqlens_k=cu_k.to(torch.int32),
|
||||
max_seqlen_q=mq,
|
||||
max_seqlen_k=mk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
|
||||
else:
|
||||
attn_output = _flash_attn_func(
|
||||
out = flash_fn(
|
||||
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||
)
|
||||
|
||||
if isinstance(attn_output, tuple):
|
||||
return attn_output[0]
|
||||
return attn_output
|
||||
|
||||
|
||||
class FlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for Flash Attention with Compile.
|
||||
|
||||
Attributes:
|
||||
cumulative_seqlens_q (`torch.LongTensor`, *optional*):
|
||||
Gets cumulative sequence length for query state.
|
||||
cumulative_seqlens_k (`torch.LongTensor`, *optional*):
|
||||
Gets cumulative sequence length for key state.
|
||||
max_length_q (`int`, *optional*):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`, *optional*):
|
||||
Maximum sequence length for key state.
|
||||
"""
|
||||
|
||||
cumulative_seqlens_q: Optional[torch.LongTensor]
|
||||
cumulative_seqlens_k: Optional[torch.LongTensor]
|
||||
max_length_q: Optional[int]
|
||||
max_length_k: Optional[int]
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
|
||||
@@ -72,6 +72,7 @@ from .integrations.tensor_parallel import (
|
||||
verify_tp_plan,
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
@@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
None to sdpa (to potentially eager).
|
||||
"""
|
||||
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
|
||||
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
|
||||
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = applicable_attn_implementation.split(":")
|
||||
if ":" in applicable_attn_implementation:
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
else:
|
||||
repo_id = attn_implementation
|
||||
kernel_name = None
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
||||
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
|
||||
flash_attention_forward, implementation=kernel
|
||||
)
|
||||
elif kernel_name is not None:
|
||||
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
|
||||
"flash_attention_2"
|
||||
]
|
||||
applicable_attn_implementation = repo_id
|
||||
except FileNotFoundError as e:
|
||||
logger.warning_once(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
|
||||
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
finally:
|
||||
return applicable_attn_implementation
|
||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
message = (
|
||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||
|
||||
@@ -104,6 +104,7 @@ from .utils import (
|
||||
is_jinja_available,
|
||||
is_jumanpp_available,
|
||||
is_keras_nlp_available,
|
||||
is_kernels_available,
|
||||
is_levenshtein_available,
|
||||
is_librosa_available,
|
||||
is_liger_kernel_available,
|
||||
@@ -586,6 +587,16 @@ def require_flash_attn(test_case):
|
||||
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|
||||
|
||||
|
||||
def require_kernels(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires the `kernels` library.
|
||||
|
||||
These tests are skipped when the `kernels` library isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_kernels_available(), "test requires Flash Attention")(test_case)
|
||||
|
||||
|
||||
def require_flash_attn_3(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Flash Attention 3.
|
||||
@@ -1103,6 +1114,11 @@ def require_torch_gpu(test_case):
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
|
||||
def require_torch_mps(test_case):
|
||||
"""Decorator marking a test that requires MPS and PyTorch."""
|
||||
return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
|
||||
|
||||
|
||||
def require_large_cpu_ram(test_case, memory: float = 80):
|
||||
"""Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
|
||||
if not is_psutil_available():
|
||||
|
||||
@@ -1142,10 +1142,14 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict:
|
||||
for placeholder in placeholders:
|
||||
# Infer placeholders from the model name and the auto modules
|
||||
if placeholder in PLACEHOLDER_TO_AUTO_MODULE:
|
||||
try:
|
||||
place_holder_value = getattr(
|
||||
getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
|
||||
PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
|
||||
).get(model_name, None)
|
||||
except ImportError:
|
||||
# In case a library is not installed, we don't want to fail the docstring generation
|
||||
place_holder_value = None
|
||||
if place_holder_value is not None:
|
||||
if isinstance(place_holder_value, (list, tuple)):
|
||||
place_holder_value = place_holder_value[0]
|
||||
@@ -1170,8 +1174,11 @@ def format_args_docstring(docstring, model_name):
|
||||
placeholders_dict = get_placeholders_dict(placeholders, model_name)
|
||||
# replace the placeholders in the docstring with the values from the placeholders_dict
|
||||
for placeholder, value in placeholders_dict.items():
|
||||
if placeholder is not None:
|
||||
try:
|
||||
docstring = docstring.replace(f"{{{placeholder}}}", value)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return docstring
|
||||
|
||||
|
||||
|
||||
@@ -86,12 +86,14 @@ from transformers.testing_utils import (
|
||||
require_deepspeed,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
require_kernels,
|
||||
require_non_hpu,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_greater_or_equal,
|
||||
require_torch_mps,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_sdpa,
|
||||
@@ -3474,18 +3476,10 @@ class ModelTesterMixin:
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.head_dim = 64 # fa2 does not always support arbitrary headim
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
@@ -3504,15 +3498,16 @@ class ModelTesterMixin:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
else:
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, output_hidden_states=True)
|
||||
|
||||
model.set_attn_implementation("sdpa")
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
@@ -3532,7 +3527,8 @@ class ModelTesterMixin:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, **other_inputs)
|
||||
else:
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
@@ -3541,12 +3537,12 @@ class ModelTesterMixin:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, **other_inputs)
|
||||
|
||||
model.set_attn_implementation("sdpa")
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
@@ -3559,10 +3555,29 @@ class ModelTesterMixin:
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, **other_inputs)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
_ = model(dummy_input, **other_inputs)
|
||||
else:
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
|
||||
@require_kernels
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_kernels_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left")
|
||||
|
||||
@require_torch_mps
|
||||
@require_kernels
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_kernels_mps_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(
|
||||
attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left"
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user