Proper_flex (#36643)
* proper performant flex attention implementation * wrapper for flex attention to compile only when triggered * wrapper for flex attention to compile only when triggered * attention mask type detection * Update src/transformers/integrations/flex_attention.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * nit * nit * nit * nit * gemma2 support * add citation for torchtune * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update flex_attention.py * nit * nit * nit * reset gemma2 modifications * nit * nit * nit * licencing * apply changes to other models * safe import --------- Co-authored-by: Sung Ching Liu <sunny19981005@outlook.com> Co-authored-by: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com> Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -48,6 +49,12 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "AriaTextConfig"
|
||||
|
||||
@@ -1014,6 +1021,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
||||
Reference in New Issue
Block a user