Allow easy registration of custom attention functions (#36889)
* Update modeling_utils.py * style * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * add to init * Update modeling_utils.py * style * update * Update modeling_utils.py * Update modeling_utils.py * style * Add some doc * Update _toctree.yml * readd it for tgi/vllm compat * CIs * CIs
This commit is contained in:
@@ -28,6 +28,7 @@ import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -2081,9 +2082,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
|
||||
)
|
||||
|
||||
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
|
||||
"eager"
|
||||
] + list(ALL_ATTENTION_FUNCTIONS.keys()):
|
||||
if (
|
||||
not isinstance(config._attn_implementation, dict)
|
||||
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
if cls._supports_flash_attn_2:
|
||||
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
@@ -2148,7 +2150,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
|
||||
)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
|
||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
config._attn_implementation = requested_attn_implementation
|
||||
elif isinstance(requested_attn_implementation, dict):
|
||||
config._attn_implementation = None
|
||||
@@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map):
|
||||
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {}
|
||||
class AttentionInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
|
||||
"""
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS.update(
|
||||
{
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
# a new instance is created (in order to locally override a given function)
|
||||
_global_mapping = {
|
||||
"flash_attention_2": flash_attention_forward,
|
||||
"flex_attention": flex_attention_forward,
|
||||
"sdpa": sdpa_attention_forward,
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
return self._local_mapping[key]
|
||||
return self._global_mapping[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow local update of the default functions without impacting other instances
|
||||
self._local_mapping.update({key: value})
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._local_mapping[key]
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter(self._global_mapping.update(self._local_mapping))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
@classmethod
|
||||
def register(cls, key: str, value: Callable):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
|
||||
|
||||
Reference in New Issue
Block a user