Kernels flash attn (#39474)
* use partial to wrap around `transformers` utils! * try to refactor? * revert one wrong change * just a nit * push * reverter watever was wrong! * some nits * fixes when there is no attention mask * bring the licence back * some fixes * nit * style * remove prints * correct dtype * fa flags for testing * update * use paged attention if requested! * updates * a clone was needed, not sure why * automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute * simplify and improve? * flash attention is kinda broken on recent cuda version so allow the opportunity to use something else * fix! * protect kernels import * update * properly parse generation config being passed * revert and update * add two tests * some fixes * fix test FA2 * takes comment into account * fixup * revert changes * revert the clone, it is only needed because the metal kernel is not doing it? * [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix mps on our side for now * Update src/transformers/integrations/flash_paged.py * no qa --------- Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -72,6 +72,7 @@ from .integrations.tensor_parallel import (
|
||||
verify_tp_plan,
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
@@ -2785,30 +2786,38 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
None to sdpa (to potentially eager).
|
||||
"""
|
||||
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
|
||||
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
|
||||
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = applicable_attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
if ":" in applicable_attn_implementation:
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
else:
|
||||
repo_id = attn_implementation
|
||||
kernel_name = None
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
||||
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
|
||||
flash_attention_forward, implementation=kernel
|
||||
)
|
||||
elif kernel_name is not None:
|
||||
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
|
||||
"flash_attention_2"
|
||||
]
|
||||
applicable_attn_implementation = repo_id
|
||||
except FileNotFoundError as e:
|
||||
logger.warning_once(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
|
||||
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
finally:
|
||||
return applicable_attn_implementation
|
||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
message = (
|
||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||
|
||||
Reference in New Issue
Block a user