From 1efcfa9ca48e014e2261db1f93ca95d801e8a342 Mon Sep 17 00:00:00 2001 From: Rupesh K Srivastava Date: Mon, 14 Apr 2025 07:53:27 -0700 Subject: [PATCH] Fix mask handling for flex attention in llama/gemma2/mistral/qwen2 (#37381) * fix BlockMask handling when using flex_attention for llama/mistral/gemma2 * fix attention_mask types * revert type hints and fixup * remove unnecessary assertion --- src/transformers/models/aria/modeling_aria.py | 10 ++++++---- .../models/bloom/modeling_bloom.py | 5 ++--- .../models/chameleon/modeling_chameleon.py | 5 ++--- .../models/codegen/modeling_codegen.py | 5 ++--- .../models/cohere/modeling_cohere.py | 10 ++++++---- .../models/cohere2/modeling_cohere2.py | 18 +++++++++++++++-- src/transformers/models/dbrx/modeling_dbrx.py | 5 ++--- .../deepseek_v3/modeling_deepseek_v3.py | 10 ++++++---- .../models/diffllama/modeling_diffllama.py | 12 ++++++----- src/transformers/models/emu3/modeling_emu3.py | 5 ++--- .../models/gemma/modeling_gemma.py | 10 ++++++---- .../models/gemma2/modeling_gemma2.py | 18 +++++++++++++++-- .../models/gemma2/modular_gemma2.py | 15 ++++++++++++-- .../models/gemma3/modeling_gemma3.py | 18 +++++++++++++++-- src/transformers/models/glm/modeling_glm.py | 10 ++++++---- .../models/gpt_neo/modeling_gpt_neo.py | 5 ++--- .../models/gpt_neox/modeling_gpt_neox.py | 10 ++++++---- .../modeling_gpt_neox_japanese.py | 5 ++--- src/transformers/models/gptj/modeling_gptj.py | 5 ++--- .../models/granite/modeling_granite.py | 10 ++++++---- .../models/granitemoe/modeling_granitemoe.py | 5 ++--- .../modeling_granitemoeshared.py | 5 ++--- .../models/helium/modeling_helium.py | 10 ++++++---- .../models/idefics/modeling_idefics.py | 5 ++--- .../models/jetmoe/modeling_jetmoe.py | 5 ++--- .../models/llama/modeling_llama.py | 12 ++++++----- .../models/longt5/modeling_longt5.py | 5 ++--- src/transformers/models/mimi/modeling_mimi.py | 13 +++++++++++- .../models/mistral/modeling_mistral.py | 20 ++++++++++++++++--- 
.../models/mistral/modular_mistral.py | 16 ++++++++++++--- .../models/mixtral/modeling_mixtral.py | 20 ++++++++++++++++--- .../models/mllama/modeling_mllama.py | 5 ++--- .../models/moonshine/modeling_moonshine.py | 10 ++++++---- .../models/moshi/modeling_moshi.py | 19 ++++++++++++++++-- src/transformers/models/mt5/modeling_mt5.py | 5 ++--- .../models/nemotron/modeling_nemotron.py | 7 +++---- src/transformers/models/olmo/modeling_olmo.py | 10 ++++++---- .../models/olmo2/modeling_olmo2.py | 10 ++++++---- src/transformers/models/opt/modeling_opt.py | 5 ++--- .../models/persimmon/modeling_persimmon.py | 5 ++--- src/transformers/models/phi/modeling_phi.py | 10 ++++++---- src/transformers/models/phi3/modeling_phi3.py | 18 +++++++++++++++-- .../modeling_phi4_multimodal.py | 13 +++++++++++- .../models/phimoe/modeling_phimoe.py | 13 +++++++++++- .../models/pix2struct/modeling_pix2struct.py | 5 ++--- .../models/pop2piano/modeling_pop2piano.py | 5 ++--- .../models/qwen2/modeling_qwen2.py | 20 ++++++++++++++++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 19 ++++++++++++++++-- .../models/qwen2_moe/modeling_qwen2_moe.py | 14 +++++++++++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 12 ++++++++++- .../models/qwen3/modeling_qwen3.py | 20 ++++++++++++++++--- .../models/qwen3_moe/modeling_qwen3_moe.py | 20 ++++++++++++++++--- .../models/stablelm/modeling_stablelm.py | 5 ++--- .../models/starcoder2/modeling_starcoder2.py | 18 +++++++++++++++-- .../modeling_switch_transformers.py | 5 ++--- src/transformers/models/t5/modeling_t5.py | 5 ++--- src/transformers/models/udop/modeling_udop.py | 5 ++--- src/transformers/models/umt5/modeling_umt5.py | 5 ++--- .../models/whisper/modeling_whisper.py | 5 ++--- 59 files changed, 423 insertions(+), 177 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 867ad34b55..abd7d1b8b0 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ 
b/src/transformers/models/aria/modeling_aria.py @@ -781,12 +781,15 @@ ARIA_TEXT_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -983,7 +986,7 @@ class AriaTextModel(AriaTextPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -996,8 +999,7 @@ class AriaTextModel(AriaTextPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 48f89810bd..f9aee9362d 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -742,7 +742,7 @@ class BloomModel(BloomPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -755,8 +755,7 @@ class BloomModel(BloomPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index ad0c255023..05e01c37bb 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1376,7 +1376,7 @@ class ChameleonModel(ChameleonPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1389,8 +1389,7 @@ class ChameleonModel(ChameleonPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 4f06094567..1bc9ceeaa3 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -588,7 +588,7 @@ class CodeGenModel(CodeGenPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -601,8 +601,7 @@ class CodeGenModel(CodeGenPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index e8e6bba237..43c5294788 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -430,12 +430,15 @@ COHERE_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -632,7 +635,7 @@ class CohereModel(CoherePreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -645,8 +648,7 @@ class CohereModel(CoherePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 65e066f90c..1d0702615e 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -45,6 +46,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Cohere2Config" @@ -438,12 +445,15 @@ COHERE2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -654,7 +664,7 @@ class Cohere2Model(Cohere2PreTrainedModel): @torch.no_grad() def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: HybridCache, @@ -666,6 +676,10 @@ class Cohere2Model(Cohere2PreTrainedModel): # as it doesn't cause dynamic control issues. if self.config._attn_implementation == "flash_attention_2": return attention_mask + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 9f89398685..7da0810645 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1114,7 +1114,7 @@ class DbrxModel(DbrxPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1127,8 +1127,7 @@ class DbrxModel(DbrxPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 586d5251b0..ccddeed663 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -573,12 +573,15 @@ DEEPSEEK_V3_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -777,7 +780,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -790,8 +793,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index f6ff065334..f5bed9bf87 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -678,12 +678,15 @@ DIFFLLAMA_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -880,7 +883,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -893,8 +896,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1259,7 +1261,7 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 1d85fcb639..3b0a4882de 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1468,7 +1468,7 @@ class Emu3TextModel(Emu3PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1481,8 +1481,7 @@ class Emu3TextModel(Emu3PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f9bcf181c5..fedbac3cea 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -395,12 +395,15 @@ GEMMA_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. 
[What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -599,7 +602,7 @@ class GemmaModel(GemmaPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -612,8 +615,7 @@ class GemmaModel(GemmaPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index ecbfedb2ad..22c4e599c2 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -43,6 +43,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -440,12 +447,15 @@ GEMMA2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -666,7 +676,7 @@ class Gemma2Model(Gemma2PreTrainedModel): @torch.no_grad() def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: HybridCache, @@ -678,6 +688,10 @@ class Gemma2Model(Gemma2PreTrainedModel): # as it doesn't cause dynamic control issues. if self.config._attn_implementation == "flash_attention_2": return attention_mask + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 384f3e0802..e06a701fc5 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -30,7 +30,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import logging +from ...utils import is_torch_flex_attn_available, logging from ..gemma.modeling_gemma import ( GemmaAttention, GemmaForCausalLM, @@ -46,6 +46,13 @@ from ..gemma.modeling_gemma import ( _CHECKPOINT_FOR_DOC = "google/gemma2-7b" + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -535,7 +542,7 @@ class Gemma2Model(GemmaModel): @torch.no_grad() def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: HybridCache, @@ -547,6 +554,10 @@ class 
Gemma2Model(GemmaModel): # as it doesn't cause dynamic control issues. if self.config._attn_implementation == "flash_attention_2": return attention_mask + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 7d0447a545..e921579326 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -40,6 +40,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -49,6 +50,12 @@ from ..auto import AutoModel, AutoModelForCausalLM from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Gemma3Config" @@ -512,12 +519,15 @@ GEMMA3_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. 
+ [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -751,7 +761,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): @torch.no_grad() def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: HybridCache, @@ -763,6 +773,10 @@ class Gemma3TextModel(Gemma3PreTrainedModel): # as it doesn't cause dynamic control issues. if self.config._attn_implementation == "flash_attention_2": return attention_mask + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 044a401402..8b0ccd9c9e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -413,12 +413,15 @@ GLM_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -615,7 +618,7 @@ class GlmModel(GlmPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -628,8 +631,7 @@ class GlmModel(GlmPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 2f723e4698..45287c025a 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -789,7 +789,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -802,8 +802,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 5a0bc1af2f..e72b82eb8b 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -386,12 +386,15 @@ GPT_NEOX_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -600,7 +603,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -613,8 +616,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 43e860a775..949e26d2b2 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -640,7 +640,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -653,8 +653,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index adfe6c584b..50344787bc 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -890,7 +890,7 @@ class GPTJModel(GPTJPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -903,8 +903,7 @@ class GPTJModel(GPTJPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 445edab65b..1d77b8b7ba 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -414,12 +414,15 @@ GRANITE_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**.
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -615,7 +618,7 @@ class GraniteModel(GranitePreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -628,8 +631,7 @@ class GraniteModel(GranitePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 6217888424..40bff42cdf 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1089,7 +1089,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1102,8 +1102,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 787b805ee8..fe62089a00 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -1034,7 +1034,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1047,8 +1047,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 61bf2f2d09..251beacfad 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -398,12 +398,15 @@ HELIUM_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices.
Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -600,7 +603,7 @@ class HeliumModel(HeliumPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -613,8 +616,7 @@ class HeliumModel(HeliumPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 770eb90ab1..73d3ac4fab 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1385,7 +1385,7 @@ class IdeficsModel(IdeficsPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1398,8 +1398,7 @@ class IdeficsModel(IdeficsPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 44e0e5b809..bf59fd074e 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1093,7 +1093,7 @@ class JetMoeModel(JetMoePreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1106,8 +1106,7 @@ class JetMoeModel(JetMoePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index eec1ecfee3..6898f168a2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -403,12 +403,15 @@ LLAMA_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**.
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -605,7 +608,7 @@ class LlamaModel(LlamaPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -618,8 +621,7 @@ class LlamaModel(LlamaPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -985,7 +987,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index da8bdca3e7..f44b28f901 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1600,7 +1600,7 @@ class LongT5Stack(LongT5PreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1613,8 +1613,7 @@ class LongT5Stack(LongT5PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 02aedb8c01..a1006ed011 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -33,6 +33,7 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -42,6 +43,12 @@ from .configuration_mimi import MimiConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -1034,7 +1041,7 @@ class MimiTransformerModel(nn.Module): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1052,6 +1059,10 @@ class MimiTransformerModel(nn.Module): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 7de6cad370..bf2cccb65b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -32,6 +32,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -39,6 +40,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_mistral import MistralConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" @@ -366,12 +373,15 @@ MISTRAL_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`].
See [`PreTrainedTokenizer.encode`] and @@ -568,7 +578,7 @@ class MistralModel(MistralPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -586,6 +596,10 @@ class MistralModel(MistralPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -1053,7 +1067,7 @@ class MistralForQuestionAnswering(MistralPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index aae4cb9d52..84062fde3e 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -10,7 +10,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import logging +from ...utils import is_torch_flex_attn_available, logging from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -27,6 +27,12 
@@ from ..llama.modeling_llama import ( from .configuration_mistral import MistralConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" @@ -120,7 +126,7 @@ class MistralModel(LlamaModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -138,6 +144,10 @@ class MistralModel(LlamaModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -300,7 +310,7 @@ class MistralForQuestionAnswering(LlamaForQuestionAnswering): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 001692bd75..1f3b7bd0a6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -55,6 +55,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -62,6 +63,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_mixtral import MixtralConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" @@ -488,12 +495,15 @@ MIXTRAL_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**.
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -697,7 +707,7 @@ class MixtralModel(MixtralPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -715,6 +725,10 @@ class MixtralModel(MixtralPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1287,7 +1301,7 @@ class MixtralForQuestionAnswering(MixtralPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index a2911a46a0..77bff448ff 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1062,7 +1062,7 @@ class MllamaPreTrainedModel(PreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1075,8 +1075,7 @@ class MllamaPreTrainedModel(PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index aaeb405a0e..97453d6614 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -718,12 +718,15 @@ MOONSHINE_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -961,7 +964,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -974,8 +977,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2.
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index dc94efd353..403c27f078 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -43,6 +43,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -55,6 +56,12 @@ from .configuration_moshi import MoshiConfig, MoshiDepthConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MoshiConfig" @@ -1261,7 +1268,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1279,6 +1286,10 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1575,7 +1586,7 @@ class MoshiModel(MoshiPreTrainedModel): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1593,6 +1604,10 @@ class MoshiModel(MoshiPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 130d58f59b..27badbbeee 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1191,7 +1191,7 @@ class MT5Stack(MT5PreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1204,8 +1204,7 @@ class MT5Stack(MT5PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # 
order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a17833912f..4e816c2a5c 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -852,7 +852,7 @@ class NemotronModel(NemotronPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -865,8 +865,7 @@ class NemotronModel(NemotronPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1231,7 +1230,7 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index bf5e80b839..7ac2dd6ad9 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -373,12 +373,15 @@ OLMO_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -575,7 +578,7 @@ class OlmoModel(OlmoPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -588,8 +591,7 @@ class OlmoModel(OlmoPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index e44ea5f62b..d09d47c7c7 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -379,12 +379,15 @@ OLMO2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -581,7 +584,7 @@ class Olmo2Model(Olmo2PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -594,8 +597,7 @@ class Olmo2Model(Olmo2PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index d0175055cc..173fe89f6e 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -641,7 +641,7 @@ class OPTDecoder(OPTPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -654,8 +654,7 @@ class OPTDecoder(OPTPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index da8a8d2927..3b492ab851 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -652,7 +652,7 @@ class PersimmonModel(PersimmonPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -665,8 +665,7 @@ class PersimmonModel(PersimmonPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8572b1546a..de1abd9963 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -373,12 +373,15 @@ PHI_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -573,7 +576,7 @@ class PhiModel(PhiPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -586,8 +589,7 @@ class PhiModel(PhiPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 1737d8c3df..2bf32cdcdc 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -47,6 +47,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -54,6 +55,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_phi3 import Phi3Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" @@ -421,12 +428,15 @@ PHI3_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -623,7 +633,7 @@ class Phi3Model(Phi3PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -641,6 +651,10 @@ class Phi3Model(Phi3PreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 8677ef9dc9..706fb6642b 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -49,6 +49,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, torch_int, @@ -56,6 +57,12 @@ from ...utils import ( from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -1923,7 +1930,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: 
torch.Tensor, past_key_values: Cache, @@ -1941,6 +1948,10 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index ab8370d8ec..42df5ac9d0 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -48,6 +49,12 @@ from .configuration_phimoe import PhimoeConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) @@ -1170,7 +1177,7 @@ class PhimoeModel(PhimoePreTrainedModel): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Phimoe def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1188,6 +1195,10 @@ class PhimoeModel(PhimoePreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index da65dbc369..192f4a10b1 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1587,7 +1587,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1600,8 +1600,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask 
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 18f6659dc0..0a8b479554 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1000,7 +1000,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1013,8 +1013,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 147f654652..8c95569a49 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -32,6 +32,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -39,6 +40,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2 import Qwen2Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" @@ -379,12 +386,15 @@ QWEN2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -581,7 +591,7 @@ class Qwen2Model(Qwen2PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -599,6 +609,10 @@ class Qwen2Model(Qwen2PreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -1066,7 +1080,7 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 5f0a9d003f..388b5f9055 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -41,7 +41,13 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + 
add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig @@ -52,6 +58,11 @@ if is_flash_attn_available(): if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) @@ -1208,7 +1219,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1226,6 +1237,10 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6c9cc40ad8..365b561eb9 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -46,6 +46,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -56,6 +57,11 @@ from .configuration_qwen2_moe import Qwen2MoeConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B" @@ -1032,7 +1038,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2Moe def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1050,6 +1056,10 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1539,7 +1549,7 @@ class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index e172f092d7..96cbb4d3a5 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -40,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -49,6 +50,11 @@ from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) @@ -1166,7 +1172,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2VL def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1184,6 +1190,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if 
isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 89f30e78f4..3cc5de2421 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -47,6 +47,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -54,6 +55,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_qwen3 import Qwen3Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" @@ -406,12 +413,15 @@ QWEN3_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -608,7 +618,7 @@ class Qwen3Model(Qwen3PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -626,6 +636,10 @@ class Qwen3Model(Qwen3PreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -1093,7 +1107,7 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 3462f565e2..6a062420f2 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -50,6 +50,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -57,6 +58,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_qwen3_moe import Qwen3MoeConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import 
make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B" @@ -502,12 +509,15 @@ QWEN3_MOE_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -711,7 +721,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -729,6 +739,10 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1301,7 +1315,7 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 69beb543b4..685a7c2372 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -906,7 +906,7 @@ class StableLmModel(StableLmPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -919,8 +919,7 @@ class StableLmModel(StableLmPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 569874aad1..ae16daaa46 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -49,6 +49,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -56,6 +57,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_starcoder2 import Starcoder2Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" @@ -369,12 +376,15 @@ STARCODER2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and @@ -558,7 +568,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -576,6 +586,10 @@ class Starcoder2Model(Starcoder2PreTrainedModel): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 4a2696a610..3dd8a13987 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1134,7 +1134,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1147,8 +1147,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely 
on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 8e3cd8d965..e8085a8f45 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1205,7 +1205,7 @@ class T5Stack(T5PreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1218,8 +1218,7 @@ class T5Stack(T5PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 5a3ce836d4..9cd7f98423 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1537,7 +1537,7 @@ class UdopStack(UdopPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1550,8 +1550,7 @@ class UdopStack(UdopPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 37015121cf..44586243c3 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -848,7 +848,7 @@ class UMT5Stack(UMT5PreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -861,8 +861,7 @@ class UMT5Stack(UMT5PreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2cb995586a..52ad48cc2d 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1375,7 +1375,7 @@ class WhisperDecoder(WhisperPreTrainedModel): # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, @@ -1388,8 +1388,7 @@ class WhisperDecoder(WhisperPreTrainedModel): if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) - if isinstance(attention_mask, BlockMask): - return attention_mask + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail