Proper_flex (#36643)

* proper performant flex attention implementation

* wrapper for flex attention to compile only when triggered

* wrapper for flex attention to compile only when triggered

* attention mask type detection

* Update src/transformers/integrations/flex_attention.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* nit

* nit

* nit

* nit

* gemma2 support

* add citation for torchtune

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update flex_attention.py

* nit

* nit

* nit

* reset gemma2 modifications

* nit

* nit

* nit

* licencing

* apply changes to other models

* safe import

---------

Co-authored-by: Sung Ching Liu <sunny19981005@outlook.com>
Co-authored-by: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
Arthur
2025-03-11 10:24:12 +01:00
committed by GitHub
parent d8663cb8c5
commit d126f35427
41 changed files with 645 additions and 12 deletions

View File

@@ -34,6 +34,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@@ -48,6 +49,12 @@ if is_torch_available():
from torch import nn
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"
@@ -1014,6 +1021,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail