From 788e1092e90eacfd1afb14166a68574ef58cafa9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 26 Mar 2025 16:15:06 +0100 Subject: [PATCH] Allow easy registration of custom attention functions (#36889) * Update modeling_utils.py * style * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * add to init * Update modeling_utils.py * style * update * Update modeling_utils.py * Update modeling_utils.py * style * Add some doc * Update _toctree.yml * readd it for tgi/vllm compat * CIs * CIs --- docs/source/en/_toctree.yml | 2 + docs/source/en/attention_interface.md | 106 +++++++++++++++++++++ docs/source/en/internal/modeling_utils.md | 6 +- src/transformers/__init__.py | 4 +- src/transformers/modeling_utils.py | 57 +++++++++-- src/transformers/utils/dummy_pt_objects.py | 7 ++ 6 files changed, 171 insertions(+), 11 deletions(-) create mode 100644 docs/source/en/attention_interface.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index bcd054113c..034ba00abc 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -29,6 +29,8 @@ title: The Transformer model family - local: attention title: Attention mechanisms + - local: attention_interface + title: Customizing attention function title: Models - sections: - local: fast_tokenizers diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md new file mode 100644 index 0000000000..275a3d88e3 --- /dev/null +++ b/docs/source/en/attention_interface.md @@ -0,0 +1,106 @@ + + +# Attention Interface + +This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with +supported models. + +## Customizing attention function + +Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping. +By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), +[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) +as well as `eager`, which is simple matrix multiplication without any optimization on top. +This is the setting you can usually choose when instantiating a model: + +```python +from transformers import AutoModelForCausalLM + +model_id = "meta-llama/Llama-3.2-1B + +# Here, using flash attention as an example +model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2") +``` + +But what if you wanted to create your own attention function? Or simply play around with existing ones, adding +a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example: + +```python +from transformers import AutoModelForCausalLM, AttentionInterface +from transformers.integrations.sdpa_attention import sdpa_attention_forward +import torch + +model_id = "meta-llama/Llama-3.2-1B + +def my_new_sdpa(*args, **kwargs): + print("I just entered the attention computation") + return sdpa_attention_forward(*args, **kwargs) + +AttentionInterface.register("my_new_sdpa", my_new_sdpa) + +model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa") +# Try running the forward with the new attention function +model(torch.ones(1, 5, dtype=int)) +``` + +You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times. + +## Dynamically switching attention function + +You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field: + +```python +# Back to use original sdpa implementation +model.config._attn_implementation = "sdpa" + +model(torch.ones(1, 5, dtype=int)) +``` + +and it will stop printing the statements, as it now uses the `sdpa` attention. +This allows to quickly change attention function, without needing to reload the model! + +## What about new args needed in my custom function? + +But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the +`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way, +you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e. + +```python +from transformers import AutoModelForCausalLM, AttentionInterface +from transformers.integrations.sdpa_attention import sdpa_attention_forward +import torch + +def custom_attention( + module: torch.nn.Module, # required arg + query: torch.Tensor, # required arg + key: torch.Tensor, # required arg + value: torch.Tensor, # required arg + attention_mask: Optional[torch.Tensor], # required arg + a_new_kwargs = None, # You can now add as many kwargs as you need + another_new_kwargs = None, # You can now add as many kwargs as you need + **kwargs, # You need to accept **kwargs as models will pass other args +) -> Tuple[torch.Tensor, Optional[torch.Tensor]] + ... # do your magic! + return attn_output, attn_weights # attn_weights are optional here + +AttentionInterface.register("custom", custom_attention) + +model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom") +# Forward pass with the new kwargs +model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...) +``` + +If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)! \ No newline at end of file diff --git a/docs/source/en/internal/modeling_utils.md b/docs/source/en/internal/modeling_utils.md index afc8123558..2fccc772cf 100644 --- a/docs/source/en/internal/modeling_utils.md +++ b/docs/source/en/internal/modeling_utils.md @@ -16,10 +16,14 @@ rendered properly in your Markdown viewer. # Custom Layers and Utilities -This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling. +This page lists all the custom layers used by the library, as well as the utility functions and classes it provides for modeling. Most of those are only useful if you are studying the code of the models in the library. +## Attention Functions + +[[autodoc]] AttentionInterface + - register ## Pytorch custom modules diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e8da536747..2db51e4770 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1482,7 +1482,7 @@ else: _import_structure["modeling_flash_attention_utils"] = [] _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"] - _import_structure["modeling_utils"] = ["PreTrainedModel"] + _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] # PyTorch models structure @@ -6727,7 +6727,7 @@ if TYPE_CHECKING: model_addition_debugger_context, ) from .modeling_rope_utils import ROPE_INIT_FUNCTIONS - from .modeling_utils import PreTrainedModel + from .modeling_utils import AttentionInterface, PreTrainedModel from .models.albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index be79f5b327..f9b3faf480 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,6 +28,7 @@ import shutil import tempfile import warnings from collections import defaultdict +from collections.abc import MutableMapping from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -2081,9 +2082,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) - if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ - "eager" - ] + list(ALL_ATTENTION_FUNCTIONS.keys()): + if ( + not isinstance(config._attn_implementation, dict) + and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() + ): message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' @@ -2148,7 +2150,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." ) torch.backends.cuda.enable_flash_sdp(False) - elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()): + elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys(): config._attn_implementation = requested_attn_implementation elif isinstance(requested_attn_implementation, dict): config._attn_implementation = None @@ -5891,12 +5893,51 @@ def get_disk_only_shard_files(device_map, weight_map): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] -ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {} +class AttentionInterface(MutableMapping): + """ + Dict-like object keeping track of allowed attention functions. You can easily add a new attention function + with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, + it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance. + """ -ALL_ATTENTION_FUNCTIONS.update( - { + # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if + # a new instance is created (in order to locally override a given function) + _global_mapping = { "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "sdpa": sdpa_attention_forward, } -) + + def __init__(self): + self._local_mapping = {} + + def __getitem__(self, key): + # First check if instance has a local override + if key in self._local_mapping: + return self._local_mapping[key] + return self._global_mapping[key] + + def __setitem__(self, key, value): + # Allow local update of the default functions without impacting other instances + self._local_mapping.update({key: value}) + + def __delitem__(self, key): + del self._local_mapping[key] + + def __iter__(self): + # Ensure we use all keys, with the overwritten ones on top + return iter(self._global_mapping.update(self._local_mapping)) + + def __len__(self): + return len(self._global_mapping.keys() | self._local_mapping.keys()) + + @classmethod + def register(cls, key: str, value: Callable): + cls._global_mapping.update({key: value}) + + def valid_keys(self) -> List[str]: + return list(self._global_mapping.keys() | self._local_mapping.keys()) + + +# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones +ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index a7051cffca..744a316161 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -549,6 +549,13 @@ def model_addition_debugger_context(*args, **kwargs): ROPE_INIT_FUNCTIONS = None +class AttentionInterface(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PreTrainedModel(metaclass=DummyObject): _backends = ["torch"]