Fix mask handling for flex attention in llama/gemma2/mistral/qwen2 (#37381)

* fix BlockMask handling when using flex_attention for llama/mistral/gemma2

* fix attention_mask types

* revert type hints and fixup

* remove unnecessary assertion
This commit is contained in:
Rupesh K Srivastava
2025-04-14 07:53:27 -07:00
committed by GitHub
parent 86064035f0
commit 1efcfa9ca4
59 changed files with 423 additions and 177 deletions

View File

@@ -781,12 +781,15 @@ ARIA_TEXT_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
but you can also pass a `BlockMask` object directly here.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
@@ -983,7 +986,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
@@ -996,8 +999,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail