Kernels flash attn (#39474)

* use partial to wrap around `transformers` utils!

* try to refactor?

* revert one wrong change

* just a nit

* push

* reverter watever was wrong!

* some nits

* fixes when there is no attention mask

* bring the licence back

* some fixes

* nit

* style

* remove prints

* correct dtype

* fa flags for testing

* update

* use paged attention if requested!

* updates

* a clone was needed, not sure why

* automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute

* simplify and improve?

* flash attention is kinda broken on recent cuda version so allow the opportunity to use something else

* fix!

* protect kernels import

* update

* properly parse generation config being passed

* revert and update

* add two tests

* some fixes

* fix test FA2

* takes comment into account

* fixup

* revert changes

* revert the clone, it is only needed because the metal kernel is not doing it?

* [docs] update attention implementation and cache docs (#39547)

* update docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* applu suggestions

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix mps on our side for now

* Update src/transformers/integrations/flash_paged.py

* no qa

---------

Co-authored-by: Vasqu <antonprogamer@gmail.com>
Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-22 15:41:06 +02:00
committed by GitHub
parent b62557e712
commit efceeaf267
9 changed files with 336 additions and 421 deletions

View File

@@ -72,6 +72,7 @@ from .integrations.tensor_parallel import (
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
None to sdpa (to potentially eager).
"""
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = applicable_attn_implementation.split(":")
kernel_name = kernel_name.strip()
if ":" in applicable_attn_implementation:
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
else:
repo_id = attn_implementation
kernel_name = None
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
if hasattr(kernel, "flash_attn_varlen_func"):
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
flash_attention_forward, implementation=kernel
)
elif kernel_name is not None:
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
"flash_attention_2"
]
applicable_attn_implementation = repo_id
except FileNotFoundError as e:
logger.warning_once(
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
"default attention implementation instead (sdpa if available, eager otherwise)."
)
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
)
finally:
return applicable_attn_implementation
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
message = (
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '