Fix mask handling for flex attention in llama/gemma2/mistral/qwen2 (#37381)
* fix BlockMask handling when using flex_attention for llama/mistral/gemma2 * fix attention_mask types * revert type hints and fixup * remove unnecessary assertion
This commit is contained in:
committed by
GitHub
parent
86064035f0
commit
1efcfa9ca4
@@ -781,12 +781,15 @@ ARIA_TEXT_INPUTS_DOCSTRING = r"""
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
|
||||
but you can also pass a `BlockMask` object directly here.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
@@ -983,7 +986,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
@@ -996,8 +999,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
||||
Reference in New Issue
Block a user