Support for Flash Attention 3 (#38972)

* Support `flash_attn_3`
Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...`

An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated

Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibiity
- `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
This commit is contained in:
EduardDurech
2025-06-25 14:39:27 +02:00
committed by GitHub
parent de98fb25a3
commit a2eb75c891
42 changed files with 698 additions and 262 deletions

View File

@@ -105,6 +105,7 @@ from .utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
is_flash_attn_3_available,
is_kernels_available,
is_offline_mode,
is_optimum_available,
@@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Flash Attention 2 support
_supports_flash_attn_2 = False
# Flash Attention 3 support
_supports_flash_attn_3 = False
# SDPA support
_supports_sdpa = False
@@ -2247,6 +2251,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_3:
message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
@@ -2282,7 +2288,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
):
sub_config._attn_implementation_internal = curr_attn_implementation
if config._attn_implementation == "flash_attention_2":
if config._attn_implementation == "flash_attention_3":
cls._check_and_enable_flash_attn_3(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
elif config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
@@ -2498,6 +2512,94 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config._attn_implementation = "flash_attention_2"
return config
@classmethod
def _check_and_enable_flash_attn_3(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
"""
Checks the availability of Flash Attention 3 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_3:
raise ValueError(
f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_3_available():
preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
if importlib.util.find_spec("flash_attn_3") is None:
raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
if major < 9:
raise ValueError(
f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
)
else:
raise ImportError(f"{preface} Flash Attention 3 is not available.")
else:
raise ValueError(
f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
)
if torch_dtype is None:
logger.warning_once(
"You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
)
if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
# Check for attention dropout, which is incompatible with FA3
if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
raise ValueError(
f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_3"
return config
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
@@ -4134,7 +4236,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
</Tip>
attn_implementation (`str`, *optional*):
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> Parameters for big model inference
@@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"flash_attention_3": flash_attention_forward,
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"paged_attention": paged_attention_forward,