Add a default value for position_ids in masking_utils (#39310)
* set default * Update masking_utils.py * add small test
This commit is contained in:
@@ -607,7 +607,7 @@ class AttentionMaskInterface(GeneralInterface):
|
||||
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
|
||||
|
||||
|
||||
def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
|
||||
tensor format (i.e. several sequences packed in the same batch dimension).
|
||||
@@ -721,7 +721,7 @@ def create_causal_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
position_ids: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
@@ -810,7 +810,7 @@ def create_sliding_window_causal_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
position_ids: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
@@ -905,7 +905,7 @@ def create_chunked_causal_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
position_ids: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
@@ -1014,7 +1014,7 @@ def create_masks_for_generate(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
position_ids: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
|
||||
@@ -22,7 +22,7 @@ if is_torch_available():
|
||||
from torch.nn.attention.flex_attention import create_block_mask
|
||||
|
||||
from transformers import LlamaConfig
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
|
||||
|
||||
|
||||
# fmt: off
|
||||
@@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase):
|
||||
|
||||
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
|
||||
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
|
||||
|
||||
def test_find_packed_sequence_indices(self):
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
|
||||
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
|
||||
|
||||
Reference in New Issue
Block a user