Allow easy registration of custom attention functions (#36889)

* Update modeling_utils.py

* style

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* add to init

* Update modeling_utils.py

* style

* update

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Add some doc

* Update _toctree.yml

* readd it for tgi/vllm compat

* CIs

* CIs
This commit is contained in:
Cyril Vallez
2025-03-26 16:15:06 +01:00
committed by GitHub
parent ad5d40de9c
commit 788e1092e9
6 changed files with 171 additions and 11 deletions

View File

@@ -29,6 +29,8 @@
title: The Transformer model family
- local: attention
title: Attention mechanisms
- local: attention_interface
title: Customizing attention function
title: Models
- sections:
- local: fast_tokenizers

View File

@@ -0,0 +1,106 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Attention Interface
This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with
supported models.
## Customizing attention function
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
as well as `eager`, which is simple matrix multiplication without any optimization on top.
This is the setting you can usually choose when instantiating a model:
```python
from transformers import AutoModelForCausalLM
model_id = "meta-llama/Llama-3.2-1B
# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
```
But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:
```python
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch
model_id = "meta-llama/Llama-3.2-1B
def my_new_sdpa(*args, **kwargs):
print("I just entered the attention computation")
return sdpa_attention_forward(*args, **kwargs)
AttentionInterface.register("my_new_sdpa", my_new_sdpa)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
# Try running the forward with the new attention function
model(torch.ones(1, 5, dtype=int))
```
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times.
## Dynamically switching attention function
You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field:
```python
# Back to use original sdpa implementation
model.config._attn_implementation = "sdpa"
model(torch.ones(1, 5, dtype=int))
```
and it will stop printing the statements, as it now uses the `sdpa` attention.
This allows to quickly change attention function, without needing to reload the model!
## What about new args needed in my custom function?
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way,
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
```python
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch
def custom_attention(
module: torch.nn.Module, # required arg
query: torch.Tensor, # required arg
key: torch.Tensor, # required arg
value: torch.Tensor, # required arg
attention_mask: Optional[torch.Tensor], # required arg
a_new_kwargs = None, # You can now add as many kwargs as you need
another_new_kwargs = None, # You can now add as many kwargs as you need
**kwargs, # You need to accept **kwargs as models will pass other args
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
... # do your magic!
return attn_output, attn_weights # attn_weights are optional here
AttentionInterface.register("custom", custom_attention)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
# Forward pass with the new kwargs
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
```
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!

View File

@@ -16,10 +16,14 @@ rendered properly in your Markdown viewer.
# Custom Layers and Utilities
This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling.
This page lists all the custom layers used by the library, as well as the utility functions and classes it provides for modeling.
Most of those are only useful if you are studying the code of the models in the library.
## Attention Functions
[[autodoc]] AttentionInterface
- register
## Pytorch custom modules

View File

@@ -1482,7 +1482,7 @@ else:
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
# PyTorch models structure
@@ -6727,7 +6727,7 @@ if TYPE_CHECKING:
model_addition_debugger_context,
)
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel
from .modeling_utils import AttentionInterface, PreTrainedModel
from .models.albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,

View File

@@ -28,6 +28,7 @@ import shutil
import tempfile
import warnings
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -2081,9 +2082,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
"eager"
] + list(ALL_ATTENTION_FUNCTIONS.keys()):
if (
not isinstance(config._attn_implementation, dict)
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -2148,7 +2150,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
@@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map):
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
class AttentionInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
"""
ALL_ATTENTION_FUNCTIONS.update(
{
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"sdpa": sdpa_attention_forward,
}
)
def __init__(self):
self._local_mapping = {}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
return self._local_mapping[key]
return self._global_mapping[key]
def __setitem__(self, key, value):
# Allow local update of the default functions without impacting other instances
self._local_mapping.update({key: value})
def __delitem__(self, key):
del self._local_mapping[key]
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter(self._global_mapping.update(self._local_mapping))
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@classmethod
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self._global_mapping.keys() | self._local_mapping.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()

View File

@@ -549,6 +549,13 @@ def model_addition_debugger_context(*args, **kwargs):
ROPE_INIT_FUNCTIONS = None
class AttentionInterface(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]