Add a default value for position_ids in masking_utils (#39310)

* set default

* Update masking_utils.py

* add small test
This commit is contained in:
Cyril Vallez
2025-07-10 18:53:40 +02:00
committed by GitHub
parent bdc8028cb3
commit 571a8c2131
2 changed files with 11 additions and 6 deletions

View File

@@ -607,7 +607,7 @@ class AttentionMaskInterface(GeneralInterface):
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
"""
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
tensor format (i.e. several sequences packed in the same batch dimension).
@@ -721,7 +721,7 @@ def create_causal_mask(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -810,7 +810,7 @@ def create_sliding_window_causal_mask(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -905,7 +905,7 @@ def create_chunked_causal_mask(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -1014,7 +1014,7 @@ def create_masks_for_generate(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
**kwargs,

View File

@@ -22,7 +22,7 @@ if is_torch_available():
from torch.nn.attention.flex_attention import create_block_mask
from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask
from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
# fmt: off
@@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase):
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
def test_find_packed_sequence_indices(self):
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())