diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8d5aab9f13..10f1a394d5 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -607,7 +607,7 @@ class AttentionMaskInterface(GeneralInterface): ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() -def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]: +def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor: """ Find the indices of the sequence to which each new query token in the sequence belongs when using packed tensor format (i.e. several sequences packed in the same batch dimension). @@ -721,7 +721,7 @@ def create_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - position_ids: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -810,7 +810,7 @@ def create_sliding_window_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - position_ids: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -905,7 +905,7 @@ def create_chunked_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - position_ids: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -1014,7 +1014,7 @@ def create_masks_for_generate( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - position_ids: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor] = None, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, **kwargs, diff --git a/tests/utils/test_masking_utils.py b/tests/utils/test_masking_utils.py index 3b162e0b08..11d7e7e72b 100644 --- a/tests/utils/test_masking_utils.py +++ b/tests/utils/test_masking_utils.py @@ -22,7 +22,7 @@ if is_torch_available(): from torch.nn.attention.flex_attention import create_block_mask from transformers import LlamaConfig - from transformers.masking_utils import create_causal_mask + from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices # fmt: off @@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase): # We compatre the str representations, as the BlockMask objects themselves cannot easily be compared self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string()) + + def test_find_packed_sequence_indices(self): + position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) + EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())