Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs (#31629)

* add DataCollatorBatchFlattening

* Update data_collator.py

* change name

* new FA2 flow if position_ids is provided

* add comments

* minor fix

* minor fix data collator

* add test cases for models

* add test case for data collator

* remove extra code

* formating for ruff check and check_repo.py

* ruff format

ruff format tests src utils

* custom_init_isort.py
This commit is contained in:
RhuiDih
2024-07-23 21:56:41 +08:00
committed by GitHub
parent 7d92009af6
commit 9cf4f2aa9a
20 changed files with 226 additions and 0 deletions

View File

@@ -26,6 +26,7 @@ from transformers import (
DataCollatorForSeq2Seq,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
@@ -1531,6 +1532,24 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (2, 8))
def test_data_collator_with_flattening(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [