Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs (#31629)
* add DataCollatorBatchFlattening * Update data_collator.py * change name * new FA2 flow if position_ids is provided * add comments * minor fix * minor fix data collator * add test cases for models * add test case for data collator * remove extra code * formating for ruff check and check_repo.py * ruff format ruff format tests src utils * custom_init_isort.py
This commit is contained in:
@@ -66,3 +66,8 @@ Examples of use can be found in the [example scripts](../examples) or [example n
|
||||
- numpy_mask_tokens
|
||||
- tf_mask_tokens
|
||||
- torch_mask_tokens
|
||||
|
||||
## DataCollatorWithFlattening
|
||||
|
||||
[[autodoc]] data.data_collator.DataCollatorWithFlattening
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ _import_structure = {
|
||||
"DataCollatorForSOP",
|
||||
"DataCollatorForTokenClassification",
|
||||
"DataCollatorForWholeWordMask",
|
||||
"DataCollatorWithFlattening",
|
||||
"DataCollatorWithPadding",
|
||||
"DefaultDataCollator",
|
||||
"default_data_collator",
|
||||
@@ -4764,6 +4765,7 @@ if TYPE_CHECKING:
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithFlattening,
|
||||
DataCollatorWithPadding,
|
||||
DefaultDataCollator,
|
||||
default_data_collator,
|
||||
|
||||
@@ -19,6 +19,7 @@ from .data_collator import (
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithFlattening,
|
||||
DataCollatorWithPadding,
|
||||
DefaultDataCollator,
|
||||
default_data_collator,
|
||||
|
||||
@@ -1611,3 +1611,38 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
|
||||
) & masked_indices[i]
|
||||
|
||||
return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorWithFlattening(DefaultDataCollator):
|
||||
"""
|
||||
Data collator used for padding free approach. Does the following:
|
||||
|
||||
- concatate the entire mini batch into single long sequence [1, total_tokens]
|
||||
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
|
||||
"""
|
||||
|
||||
def __init__(self, *args, return_position_ids=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.return_position_ids = return_position_ids
|
||||
warnings.warn(
|
||||
"Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
|
||||
"Make sure your attention computation is able to handle it!"
|
||||
)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
is_labels_provided = "labels" in features[0]
|
||||
ret = {"input_ids": [], "labels": []}
|
||||
if self.return_position_ids:
|
||||
ret.update({"position_ids": []})
|
||||
for idx in range(0, len(features)):
|
||||
ret["input_ids"] += features[idx]["input_ids"]
|
||||
if is_labels_provided:
|
||||
ret["labels"] += [-100] + features[idx]["labels"][1:]
|
||||
else:
|
||||
ret["labels"] += [-100] + features[idx]["input_ids"][1:]
|
||||
if self.return_position_ids:
|
||||
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
|
||||
return default_data_collator([ret], return_tensors)
|
||||
|
||||
@@ -130,6 +130,56 @@ def _upad_input(
|
||||
)
|
||||
|
||||
|
||||
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cummulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
|
||||
NOTE: ideally cummulative lengths should be prepared at the data collator stage
|
||||
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
|
||||
Return:
|
||||
query (`torch.Tensor):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
indices_q (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input target sequence.
|
||||
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.view(-1, query.size(-2), query.size(-1))
|
||||
key = key.view(-1, key.size(-2), key.size(-1))
|
||||
value = value.view(-1, value.size(-2), value.size(-1))
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
)
|
||||
|
||||
max_length = position_ids.max() + 1
|
||||
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
@@ -138,6 +188,7 @@ def _flash_attention_forward(
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
use_top_left_mask: bool = False,
|
||||
@@ -210,6 +261,34 @@ def _flash_attention_forward(
|
||||
**flash_kwargs,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
|
||||
# if position_ids is provided and check not all examples (row) contain only 1 sequence,
|
||||
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
|
||||
batch_size = query_states.size(0)
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||
|
||||
@@ -415,6 +415,7 @@ class DbrxFlashAttention2(DbrxAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -602,6 +602,7 @@ class FalconFlashAttention2(FalconAttention):
|
||||
value_layer,
|
||||
attention_mask,
|
||||
query_length,
|
||||
position_ids=position_ids,
|
||||
dropout=attn_dropout,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -393,6 +393,7 @@ class GemmaFlashAttention2(GemmaAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -503,6 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -382,6 +382,7 @@ class MistralFlashAttention2(MistralAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -488,6 +488,7 @@ class MixtralFlashAttention2(MixtralAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -428,6 +428,7 @@ class OlmoFlashAttention2(OlmoAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -501,6 +501,7 @@ class PhiFlashAttention2(PhiAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=attn_dropout,
|
||||
softmax_scale=None,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -563,6 +563,7 @@ class Phi3FlashAttention2(Phi3Attention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=attn_dropout,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
|
||||
@@ -429,6 +429,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -508,6 +508,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -606,6 +606,7 @@ class StableLmFlashAttention2(StableLmAttention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -404,6 +404,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
is_causal=self.is_causal,
|
||||
|
||||
@@ -4327,6 +4327,78 @@ class ModelTesterMixin:
|
||||
# with attention mask
|
||||
_ = model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
|
||||
# ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
k: v[dummy_attention_mask.bool()].unsqueeze(0)
|
||||
for k, v in inputs_dict.items()
|
||||
if not k == "attention_mask"
|
||||
}
|
||||
# add position_ids
|
||||
padfree_inputs_dict["position_ids"] = (
|
||||
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
|
||||
.long()
|
||||
.unsqueeze(0)
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**padfree_inputs_dict)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithFlattening,
|
||||
DataCollatorWithPadding,
|
||||
default_data_collator,
|
||||
is_tf_available,
|
||||
@@ -1531,6 +1532,24 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, (2, 8))
|
||||
|
||||
def test_data_collator_with_flattening(self):
|
||||
features = [
|
||||
{"input_ids": [10, 11, 12]},
|
||||
{"input_ids": [20, 21, 22, 23, 24, 25]},
|
||||
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
|
||||
]
|
||||
|
||||
data_collator = DataCollatorWithFlattening(return_tensors="np")
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, (1, 16))
|
||||
self.assertEqual(
|
||||
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
|
||||
)
|
||||
self.assertNotIn("attention_mask", batch)
|
||||
self.assertIn("position_ids", batch)
|
||||
self.assertEqual(batch["position_ids"].shape, (1, 16))
|
||||
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
|
||||
|
||||
def test_data_collator_for_token_classification(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [
|
||||
|
||||
Reference in New Issue
Block a user