Add a default value for position_ids in masking_utils (#39310)

* set default

* Update masking_utils.py

* add small test
This commit is contained in:
Cyril Vallez
2025-07-10 18:53:40 +02:00
committed by GitHub
parent bdc8028cb3
commit 571a8c2131
2 changed files with 11 additions and 6 deletions

View File

@@ -607,7 +607,7 @@ class AttentionMaskInterface(GeneralInterface):
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]: def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
""" """
Find the indices of the sequence to which each new query token in the sequence belongs when using packed Find the indices of the sequence to which each new query token in the sequence belongs when using packed
tensor format (i.e. several sequences packed in the same batch dimension). tensor format (i.e. several sequences packed in the same batch dimension).
@@ -721,7 +721,7 @@ def create_causal_mask(
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Optional[Cache], past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor], position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None, or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]: ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -810,7 +810,7 @@ def create_sliding_window_causal_mask(
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Optional[Cache], past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor], position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None, or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]: ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -905,7 +905,7 @@ def create_chunked_causal_mask(
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Optional[Cache], past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor], position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None, or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]: ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -1014,7 +1014,7 @@ def create_masks_for_generate(
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Optional[Cache], past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor], position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None, or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None,
**kwargs, **kwargs,

View File

@@ -22,7 +22,7 @@ if is_torch_available():
from torch.nn.attention.flex_attention import create_block_mask from torch.nn.attention.flex_attention import create_block_mask
from transformers import LlamaConfig from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
# fmt: off # fmt: off
@@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase):
# We compare the str representations, as the BlockMask objects themselves cannot easily be compared # We compare the str representations, as the BlockMask objects themselves cannot easily be compared
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string()) self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
def test_find_packed_sequence_indices(self):
    """Check the per-token sequence index computed from packed `position_ids`.

    Each row of `position_ids` packs several sequences back to back; the expected
    output increments by one every time the position ids restart at 0.
    """
    # Row 0 packs three sequences (lengths 4, 2, 4); row 1 packs two (lengths 6, 4).
    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
    EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])

    sequence_indices = find_packed_sequence_indices(position_ids)

    # Elementwise `==` broadcasts, so a result with the wrong shape could still compare
    # equal everywhere; pin the shape explicitly before comparing the values.
    self.assertEqual(sequence_indices.shape, EXPECTED_SEQUENCE_INDICES.shape)
    self.assertTrue((sequence_indices == EXPECTED_SEQUENCE_INDICES).all())