Add a default value for position_ids in masking_utils (#39310)
* set default * Update masking_utils.py * add small test
This commit is contained in:
@@ -22,7 +22,7 @@ if is_torch_available():
|
||||
from torch.nn.attention.flex_attention import create_block_mask
|
||||
|
||||
from transformers import LlamaConfig
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
|
||||
|
||||
|
||||
# fmt: off
|
||||
@@ -130,3 +130,8 @@ class MaskTest(unittest.TestCase):
|
||||
|
||||
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
|
||||
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
|
||||
|
||||
def test_find_packed_sequence_indices(self):
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
|
||||
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
|
||||
|
||||
Reference in New Issue
Block a user