Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs (#31629)
* add DataCollatorBatchFlattening
* Update data_collator.py
* change name
* new FA2 flow if position_ids is provided
* add comments
* minor fix
* minor fix data collator
* add test cases for models
* add test case for data collator
* remove extra code
* formatting for ruff check and check_repo.py
* ruff format ruff format tests src utils
* custom_init_isort.py
This commit is contained in:
@@ -4327,6 +4327,78 @@ class ModelTesterMixin:
|
||||
# with attention mask
|
||||
_ = model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
    """Check that Flash Attention 2 gives the same logits for padded and padding-free inputs.

    For every generative model class, the model is saved, reloaded in float16 with
    ``attn_implementation="flash_attention_2"``, and run twice:

    1. on the padded batch together with its attention mask;
    2. on the same tokens flattened into a single row (padding removed, no
       attention mask) with ``position_ids`` that restart at 0 for each sequence.

    The logits at the non-padded positions must have identical argmax and agree
    within float16 machine epsilon.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    max_new_tokens = 30

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # The comparison only makes sense when the dummy batch actually contains padding.
        # Skip explicitly instead of using a bare `assert`: an assert is stripped under
        # `python -O`, raises KeyError when the tester provides no attention mask, and
        # reports a failure where the test is simply not applicable.
        if "attention_mask" not in inputs_dict or 0 not in inputs_dict["attention_mask"]:
            self.skipTest("Model dummy inputs should contain padding in their attention mask")

        dummy_input = inputs_dict[model_class.main_input_name]
        if dummy_input.dtype in [torch.float32, torch.bfloat16]:
            dummy_input = dummy_input.to(torch.float16)

        # make sure that all models have enough positions for generation
        if hasattr(config, "max_position_embeddings"):
            config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # ensure left padding, to adapt for some models:
            # a 0 in the last column means the mask is right-padded, so mirror it
            if 0 in inputs_dict["attention_mask"][:, -1]:
                inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
            dummy_attention_mask = inputs_dict["attention_mask"]
            # overwrite the masked-out positions with the pad token
            inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id

            model = (
                model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                )
                .to(torch_device)
                .eval()
            )

            # flatten: keep only the non-padded entries of every input and pack them
            # into a single row of shape (1, total_tokens); the attention mask is dropped
            padfree_inputs_dict = {
                k: v[dummy_attention_mask.bool()].unsqueeze(0)
                for k, v in inputs_dict.items()
                if k != "attention_mask"
            }
            # add position_ids: one 0..length-1 ramp per original sequence, concatenated,
            # which is the only information marking sequence boundaries in the packed row
            padfree_inputs_dict["position_ids"] = (
                torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
                .long()
                .unsqueeze(0)
                .to(torch_device)
            )

            res_padded = model(**inputs_dict)
            res_padfree = model(**padfree_inputs_dict)

            # select the padded run's logits at real-token positions so both tensors line up
            logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
            logits_padfree = res_padfree.logits[0]

            # predicted tokens must match exactly
            torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
            # acceptable numerical instability
            tol = torch.finfo(torch.float16).eps
            torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithFlattening,
|
||||
DataCollatorWithPadding,
|
||||
default_data_collator,
|
||||
is_tf_available,
|
||||
@@ -1531,6 +1532,24 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, (2, 8))
|
||||
|
||||
def test_data_collator_with_flattening(self):
    """The flattening collator packs all sequences into one row and emits per-sequence position ids."""
    sequences = [
        [10, 11, 12],
        [20, 21, 22, 23, 24, 25],
        [30, 31, 32, 33, 34, 35, 36],
    ]
    features = [{"input_ids": tokens} for tokens in sequences]

    # Expected packed row: the sequences concatenated in order (3 + 6 + 7 = 16 tokens).
    expected_input_ids = [token for tokens in sequences for token in tokens]
    # Expected positions: a fresh 0..len-1 ramp for each sequence.
    expected_position_ids = [position for tokens in sequences for position in range(len(tokens))]
    expected_shape = (1, len(expected_input_ids))

    collate = DataCollatorWithFlattening(return_tensors="np")
    batch = collate(features)

    packed_ids = batch["input_ids"]
    self.assertEqual(packed_ids.shape, expected_shape)
    self.assertEqual(packed_ids[0].tolist(), expected_input_ids)

    # Packing replaces the attention mask with position ids.
    self.assertNotIn("attention_mask", batch)
    self.assertIn("position_ids", batch)

    packed_positions = batch["position_ids"]
    self.assertEqual(packed_positions.shape, expected_shape)
    self.assertEqual(packed_positions[0].tolist(), expected_position_ids)
|
||||
|
||||
def test_data_collator_for_token_classification(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [
|
||||
|
||||
Reference in New Issue
Block a user