Allow easy registration of custom attention functions (#36889)

* Update modeling_utils.py

* style

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* add to init

* Update modeling_utils.py

* style

* update

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Add some doc

* Update _toctree.yml

* readd it for tgi/vllm compat

* CIs

* CIs
This commit is contained in:
Cyril Vallez
2025-03-26 16:15:06 +01:00
committed by GitHub
parent ad5d40de9c
commit 788e1092e9
6 changed files with 171 additions and 11 deletions

View File

@@ -28,6 +28,7 @@ import shutil
import tempfile
import warnings
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -2081,9 +2082,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
"eager"
] + list(ALL_ATTENTION_FUNCTIONS.keys()):
if (
not isinstance(config._attn_implementation, dict)
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -2148,7 +2150,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
@@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map):
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
class AttentionInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
"""
ALL_ATTENTION_FUNCTIONS.update(
{
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"sdpa": sdpa_attention_forward,
}
)
def __init__(self):
self._local_mapping = {}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
return self._local_mapping[key]
return self._global_mapping[key]
def __setitem__(self, key, value):
# Allow local update of the default functions without impacting other instances
self._local_mapping.update({key: value})
def __delitem__(self, key):
del self._local_mapping[key]
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter(self._global_mapping.update(self._local_mapping))
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@classmethod
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self._global_mapping.keys() | self._local_mapping.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()