Allow easy registration of custom attention functions (#36889)
* Update modeling_utils.py * style * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * add to init * Update modeling_utils.py * style * update * Update modeling_utils.py * Update modeling_utils.py * style * Add some doc * Update _toctree.yml * readd it for tgi/vllm compat * CIs * CIs
This commit is contained in:
@@ -29,6 +29,8 @@
|
||||
title: The Transformer model family
|
||||
- local: attention
|
||||
title: Attention mechanisms
|
||||
- local: attention_interface
|
||||
title: Customizing attention function
|
||||
title: Models
|
||||
- sections:
|
||||
- local: fast_tokenizers
|
||||
|
||||
106
docs/source/en/attention_interface.md
Normal file
106
docs/source/en/attention_interface.md
Normal file
@@ -0,0 +1,106 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Attention Interface
|
||||
|
||||
This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with
|
||||
supported models.
|
||||
|
||||
## Customizing attention function
|
||||
|
||||
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
|
||||
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
|
||||
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
|
||||
as well as `eager`, which is simple matrix multiplication without any optimization on top.
|
||||
This is the setting you can usually choose when instantiating a model:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B
|
||||
|
||||
# Here, using flash attention as an example
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
|
||||
```
|
||||
|
||||
But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
|
||||
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AttentionInterface
|
||||
from transformers.integrations.sdpa_attention import sdpa_attention_forward
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B
|
||||
|
||||
def my_new_sdpa(*args, **kwargs):
|
||||
print("I just entered the attention computation")
|
||||
return sdpa_attention_forward(*args, **kwargs)
|
||||
|
||||
AttentionInterface.register("my_new_sdpa", my_new_sdpa)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
|
||||
# Try running the forward with the new attention function
|
||||
model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times.
|
||||
|
||||
## Dynamically switching attention function
|
||||
|
||||
You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field:
|
||||
|
||||
```python
|
||||
# Back to use original sdpa implementation
|
||||
model.config._attn_implementation = "sdpa"
|
||||
|
||||
model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
and it will stop printing the statements, as it now uses the `sdpa` attention.
|
||||
This allows to quickly change attention function, without needing to reload the model!
|
||||
|
||||
## What about new args needed in my custom function?
|
||||
|
||||
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
|
||||
`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way,
|
||||
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AttentionInterface
|
||||
from transformers.integrations.sdpa_attention import sdpa_attention_forward
|
||||
import torch
|
||||
|
||||
def custom_attention(
|
||||
module: torch.nn.Module, # required arg
|
||||
query: torch.Tensor, # required arg
|
||||
key: torch.Tensor, # required arg
|
||||
value: torch.Tensor, # required arg
|
||||
attention_mask: Optional[torch.Tensor], # required arg
|
||||
a_new_kwargs = None, # You can now add as many kwargs as you need
|
||||
another_new_kwargs = None, # You can now add as many kwargs as you need
|
||||
**kwargs, # You need to accept **kwargs as models will pass other args
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
|
||||
... # do your magic!
|
||||
return attn_output, attn_weights # attn_weights are optional here
|
||||
|
||||
AttentionInterface.register("custom", custom_attention)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
|
||||
# Forward pass with the new kwargs
|
||||
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
|
||||
```
|
||||
|
||||
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
|
||||
@@ -16,10 +16,14 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Custom Layers and Utilities
|
||||
|
||||
This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling.
|
||||
This page lists all the custom layers used by the library, as well as the utility functions and classes it provides for modeling.
|
||||
|
||||
Most of those are only useful if you are studying the code of the models in the library.
|
||||
|
||||
## Attention Functions
|
||||
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Pytorch custom modules
|
||||
|
||||
|
||||
@@ -1482,7 +1482,7 @@ else:
|
||||
_import_structure["modeling_flash_attention_utils"] = []
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel"]
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
|
||||
|
||||
# PyTorch models structure
|
||||
|
||||
@@ -6727,7 +6727,7 @@ if TYPE_CHECKING:
|
||||
model_addition_debugger_context,
|
||||
)
|
||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .modeling_utils import AttentionInterface, PreTrainedModel
|
||||
from .models.albert import (
|
||||
AlbertForMaskedLM,
|
||||
AlbertForMultipleChoice,
|
||||
|
||||
@@ -28,6 +28,7 @@ import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -2081,9 +2082,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
|
||||
)
|
||||
|
||||
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
|
||||
"eager"
|
||||
] + list(ALL_ATTENTION_FUNCTIONS.keys()):
|
||||
if (
|
||||
not isinstance(config._attn_implementation, dict)
|
||||
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
if cls._supports_flash_attn_2:
|
||||
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
@@ -2148,7 +2150,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
|
||||
)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
|
||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
config._attn_implementation = requested_attn_implementation
|
||||
elif isinstance(requested_attn_implementation, dict):
|
||||
config._attn_implementation = None
|
||||
@@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map):
|
||||
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
|
||||
class AttentionInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
|
||||
"""
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS.update(
|
||||
{
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
# a new instance is created (in order to locally override a given function)
|
||||
_global_mapping = {
|
||||
"flash_attention_2": flash_attention_forward,
|
||||
"flex_attention": flex_attention_forward,
|
||||
"sdpa": sdpa_attention_forward,
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
return self._local_mapping[key]
|
||||
return self._global_mapping[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow local update of the default functions without impacting other instances
|
||||
self._local_mapping.update({key: value})
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._local_mapping[key]
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter(self._global_mapping.update(self._local_mapping))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
@classmethod
|
||||
def register(cls, key: str, value: Callable):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
|
||||
|
||||
@@ -549,6 +549,13 @@ def model_addition_debugger_context(*args, **kwargs):
|
||||
ROPE_INIT_FUNCTIONS = None
|
||||
|
||||
|
||||
class AttentionInterface(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user