Add a default value for position_ids in masking_utils (#39310)
* set default * Update masking_utils.py * add small test
This commit is contained in:
@@ -607,7 +607,7 @@ class AttentionMaskInterface(GeneralInterface):
|
|||||||
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
|
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
|
||||||
|
|
||||||
|
|
||||||
def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
|
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
|
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
|
||||||
tensor format (i.e. several sequences packed in the same batch dimension).
|
tensor format (i.e. several sequences packed in the same batch dimension).
|
||||||
@@ -721,7 +721,7 @@ def create_causal_mask(
|
|||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor],
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Optional[Cache],
|
past_key_values: Optional[Cache],
|
||||||
position_ids: Optional[torch.Tensor],
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
or_mask_function: Optional[Callable] = None,
|
or_mask_function: Optional[Callable] = None,
|
||||||
and_mask_function: Optional[Callable] = None,
|
and_mask_function: Optional[Callable] = None,
|
||||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||||
@@ -810,7 +810,7 @@ def create_sliding_window_causal_mask(
|
|||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor],
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Optional[Cache],
|
past_key_values: Optional[Cache],
|
||||||
position_ids: Optional[torch.Tensor],
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
or_mask_function: Optional[Callable] = None,
|
or_mask_function: Optional[Callable] = None,
|
||||||
and_mask_function: Optional[Callable] = None,
|
and_mask_function: Optional[Callable] = None,
|
||||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||||
@@ -905,7 +905,7 @@ def create_chunked_causal_mask(
|
|||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor],
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Optional[Cache],
|
past_key_values: Optional[Cache],
|
||||||
position_ids: Optional[torch.Tensor],
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
or_mask_function: Optional[Callable] = None,
|
or_mask_function: Optional[Callable] = None,
|
||||||
and_mask_function: Optional[Callable] = None,
|
and_mask_function: Optional[Callable] = None,
|
||||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||||
@@ -1014,7 +1014,7 @@ def create_masks_for_generate(
|
|||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor],
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Optional[Cache],
|
past_key_values: Optional[Cache],
|
||||||
position_ids: Optional[torch.Tensor],
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
or_mask_function: Optional[Callable] = None,
|
or_mask_function: Optional[Callable] = None,
|
||||||
and_mask_function: Optional[Callable] = None,
|
and_mask_function: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ if is_torch_available():
|
|||||||
from torch.nn.attention.flex_attention import create_block_mask
|
from torch.nn.attention.flex_attention import create_block_mask
|
||||||
|
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
from transformers.masking_utils import create_causal_mask
|
from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase):
|
|||||||
|
|
||||||
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
|
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
|
||||||
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
|
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
|
||||||
|
|
||||||
|
def test_find_packed_sequence_indices(self):
|
||||||
|
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||||
|
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
|
||||||
|
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
|
||||||
|
|||||||
Reference in New Issue
Block a user