add FlashAttentionKwargs and seq_idx to flat collator (#36456)
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simply wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
This commit is contained in:
@@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
||||
|
||||
def test_data_collator_with_flattening(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="pt")
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
self.assertIn("position_ids", batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
|
||||
def test_data_collator_with_flattening_flash_attn_kwargs(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
for expected_key in [
|
||||
"position_ids",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
]:
|
||||
self.assertIn(expected_key, batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
|
||||
self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
|
||||
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
|
||||
self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
|
||||
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
|
||||
# The flash attn max_length_{k,q} are simple python ints
|
||||
self.assertEqual(batch["max_length_k"], 7)
|
||||
self.assertEqual(batch["max_length_q"], 7)
|
||||
|
||||
def test_data_collator_with_flattening_seq_idx(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
for expected_key in [
|
||||
"position_ids",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertIn(expected_key, batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
|
||||
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
|
||||
|
||||
def test_data_collator_for_token_classification(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [
|
||||
@@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="np")
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
self.assertIn("position_ids", batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertNotIn("attention_mask", batch)
|
||||
self.assertIn("position_ids", batch)
|
||||
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
|
||||
def test_data_collator_with_flattening_flash_attn_kwargs(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
for expected_key in [
|
||||
"position_ids",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
]:
|
||||
self.assertIn(expected_key, batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
|
||||
self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
|
||||
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
|
||||
self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
|
||||
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
|
||||
# The flash attn max_length_{k,q} are simple python ints
|
||||
self.assertEqual(batch["max_length_k"], 7)
|
||||
self.assertEqual(batch["max_length_q"], 7)
|
||||
|
||||
def test_data_collator_with_flattening_seq_idx(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
|
||||
batch = data_collator(features)
|
||||
|
||||
for unexpected_key in [
|
||||
"attention_mask",
|
||||
"cu_seq_lens_k",
|
||||
"cu_seq_lens_q",
|
||||
"max_length_k",
|
||||
"max_length_q",
|
||||
]:
|
||||
self.assertNotIn(unexpected_key, batch)
|
||||
for expected_key in [
|
||||
"position_ids",
|
||||
"seq_idx",
|
||||
]:
|
||||
self.assertIn(expected_key, batch)
|
||||
|
||||
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
|
||||
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
|
||||
|
||||
def test_data_collator_for_token_classification(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [
|
||||
|
||||
Reference in New Issue
Block a user