add FlashAttentionKwargs and seq_idx to flat collator (#36456)

* add flash attn kwargs to flattening collator

* add return_seq_idx option

* doc string edits

* cleaner max len updates

* various fixes

* temp testing code

* return int32 seq_idx and FlashAttnKwargs

* DataCollatorIntegrationTest impl

* fix batch dims and dtypes

* fill out remaining collator tests

* test name change and fmt

* rm unused var

* fmt

* minor change

* fmt

* add missing pos_ids check

* consistent {np,pt,tf} tests

* split pt tests into 3, like np/tf tests

* mv comment, rename fa test

* remove batch dim comment

* simply wrapping

* compute cu_seq_len/max_length once

* fmt

* remove tf code

* rm warning

* move separator_id back to 2nd pos

* use cleaner lists in tests

* ret -> batch

* fmt

* attr ordering

* use py ints for max_length_{k,q}
This commit is contained in:
Garrett Goon
2025-04-16 09:45:03 -04:00
committed by GitHub
parent 9ddcf5fce5
commit 503541d7ef
3 changed files with 318 additions and 13 deletions

View File

@@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
def test_data_collator_with_flattening(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt")
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
def test_data_collator_with_flattening_flash_attn_kwargs(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
# The flash attn max_length_{k,q} are simple python ints
self.assertEqual(batch["max_length_k"], 7)
self.assertEqual(batch["max_length_q"], 7)
def test_data_collator_with_flattening_seq_idx(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"seq_idx",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
@@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
def test_data_collator_with_flattening_flash_attn_kwargs(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
# The flash attn max_length_{k,q} are simple python ints
self.assertEqual(batch["max_length_k"], 7)
self.assertEqual(batch["max_length_q"], 7)
def test_data_collator_with_flattening_seq_idx(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"seq_idx",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [