From 503541d7efeb3750944cc716ddc0e068b0cf0a48 Mon Sep 17 00:00:00 2001 From: Garrett Goon <44747910+garrett361@users.noreply.github.com> Date: Wed, 16 Apr 2025 09:45:03 -0400 Subject: [PATCH] add FlashAttentionKwargs and seq_idx to flat collator (#36456) * add flash attn kwargs to flattening collator * add return_seq_idx option * doc string edits * cleaner max len updates * various fixes * temp testing code * return int32 seq_idx and FlashAttnKwargs * DataCollatorIntegrationTest impl * fix batch dims and dtypes * fill out remaining collator tests * test name change and fmt * rm unused var * fmt * minor change * fmt * add missing pos_ids check * consistent {np,pt,tf} tests * split pt tests into 3, like np/tf tests * mv comment, rename fa test * remove batch dim comment * simply wrapping * compute cu_seq_len/max_length once * fmt * remove tf code * rm warning * move separator_id back to 2nd pos * use cleaner lists in tests * ret -> batch * fmt * attr ordering * use py ints for max_length_{k,q} --- src/transformers/data/data_collator.py | 74 ++++++++-- tests/test_modeling_common.py | 73 ++++++++++ tests/trainer/test_data_collator.py | 184 ++++++++++++++++++++++++- 3 files changed, 318 insertions(+), 13 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 07490a25f9..55aed55a13 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1974,9 +1974,11 @@ class DataCollatorWithFlattening(DefaultDataCollator): """ Data collator used for padding free approach. Does the following: - - concatate the entire mini batch into single long sequence [1, total_tokens] + - concatenates the entire mini batch into single long sequence of shape [1, total_tokens] - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100 - - no padding will be added, returns `input_ids`, `labels` and `position_ids` + - no padding will be added, returns `input_ids`, `labels` and `position_ids` by default + - optionally returns the kwargs contained in FlashAttentionKwargs + - optionally returns seq_idx indicating which sequence each token belongs to @@ -1986,10 +1988,23 @@ class DataCollatorWithFlattening(DefaultDataCollator): """ - def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs): + def __init__( + self, + *args, + return_position_ids=True, + separator_id=-100, + return_flash_attn_kwargs=False, + return_seq_idx=False, + **kwargs, + ): super().__init__(*args, **kwargs) self.return_position_ids = return_position_ids self.separator_id = separator_id + self.return_flash_attn_kwargs = return_flash_attn_kwargs + self.return_seq_idx = return_seq_idx + self._int_64_keys = {"labels", "position_ids", "input_ids"} + self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"} + self._py_int_keys = {"max_length_q", "max_length_k"} def __call__(self, features, return_tensors=None, separator_id=None): if return_tensors is None: @@ -1997,15 +2012,52 @@ class DataCollatorWithFlattening(DefaultDataCollator): if separator_id is None: separator_id = self.separator_id is_labels_provided = "labels" in features[0] - ret = {"input_ids": [], "labels": []} + batch = {"input_ids": [], "labels": []} if self.return_position_ids: - ret.update({"position_ids": []}) - for idx in range(0, len(features)): - ret["input_ids"] += features[idx]["input_ids"] + batch.update({"position_ids": []}) + if self.return_seq_idx: + batch.update({"seq_idx": []}) + if self.return_flash_attn_kwargs: + cu_seq_lens = [0] + max_length = 0 + for seq_idx, sample in enumerate(features): + input_ids = sample["input_ids"] + batch["input_ids"] += input_ids if is_labels_provided: - ret["labels"] += [separator_id] + features[idx]["labels"][1:] + batch["labels"] += [separator_id] + sample["labels"][1:] else: - ret["labels"] += [separator_id] + features[idx]["input_ids"][1:] + batch["labels"] += [separator_id] + input_ids[1:] if self.return_position_ids: - ret["position_ids"] += list(range(len(features[idx]["input_ids"]))) - return default_data_collator([ret], return_tensors) + batch["position_ids"] += list(range(len(input_ids))) + if self.return_seq_idx: + batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))] + if self.return_flash_attn_kwargs: + cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids)) + max_length = max(max_length, len(input_ids)) + + if self.return_flash_attn_kwargs: + batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens + batch["max_length_q"] = batch["max_length_k"] = max_length + + # FlashAttentionKwargs and seq_idx are expected to be int32s. + if return_tensors == "pt": + import torch + + data_cls = torch.tensor + dtype_64 = torch.int64 + dtype_32 = torch.int32 + elif return_tensors == "np": + data_cls = np.array + dtype_64 = np.int64 + dtype_32 = np.int32 + else: + raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported') + + for k, v in batch.items(): + if k in self._batch_dim_keys: + v = [v] + # Flash attention max_len_{q,k} are python ints + if k not in self._py_int_keys: + batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32) + + return batch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3e360c05f0..00fb2a77d8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -34,6 +34,7 @@ from transformers import ( AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, + DataCollatorWithFlattening, PretrainedConfig, PreTrainedModel, is_torch_available, @@ -4170,6 +4171,78 @@ class ModelTesterMixin: tol = torch.finfo(torch.float16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: + self.skipTest("Model dummy inputs should contain padding in their attention mask") + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + if "position_ids" not in inspect.signature(model.forward).parameters: + self.skipTest("Model does not support position_ids") + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # ensure left padding, to adapt for some models + if 0 in inputs_dict["attention_mask"][:, -1]: + inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) + dummy_attention_mask = inputs_dict["attention_mask"] + inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id + + model = ( + model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ) + .to(torch_device) + .eval() + ) + + # flatten + features = [ + {"input_ids": i[a.bool()].tolist()} + for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + ] + + # add position_ids + fa_kwargs + data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) + batch = data_collator(features) + batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()} + + res_padded = model(**inputs_dict) + res_padfree = model(**batch_cuda) + + logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] + logits_padfree = res_padfree.logits[0] + + torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) + # acceptable numerical instability + tol = torch.finfo(torch.float16).eps + torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index a88641ca16..d4360c32c9 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase): batch = data_collator(features) self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8])) + def test_data_collator_with_flattening(self): + features = [ + {"input_ids": [10, 11, 12]}, + {"input_ids": [20, 21, 22, 23, 24, 25]}, + {"input_ids": [30, 31, 32, 33, 34, 35, 36]}, + ] + + data_collator = DataCollatorWithFlattening(return_tensors="pt") + batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + "seq_idx", + ]: + self.assertNotIn(unexpected_key, batch) + self.assertIn("position_ids", batch) + + self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16])) + self.assertEqual( + batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] + ) + self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16])) + self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + + def test_data_collator_with_flattening_flash_attn_kwargs(self): + features = [ + {"input_ids": [10, 11, 12]}, + {"input_ids": [20, 21, 22, 23, 24, 25]}, + {"input_ids": [30, 31, 32, 33, 34, 35, 36]}, + ] + data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) + batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "seq_idx", + ]: + self.assertNotIn(unexpected_key, batch) + for expected_key in [ + "position_ids", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + ]: + self.assertIn(expected_key, batch) + + self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16])) + self.assertEqual( + batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] + ) + self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16])) + self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + + self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4])) + self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16]) + self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4])) + self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16]) + # The flash attn max_length_{k,q} are simple python ints + self.assertEqual(batch["max_length_k"], 7) + self.assertEqual(batch["max_length_q"], 7) + + def test_data_collator_with_flattening_seq_idx(self): + features = [ + {"input_ids": [10, 11, 12]}, + {"input_ids": [20, 21, 22, 23, 24, 25]}, + {"input_ids": [30, 31, 32, 33, 34, 35, 36]}, + ] + data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True) + batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + ]: + self.assertNotIn(unexpected_key, batch) + for expected_key in [ + "position_ids", + "seq_idx", + ]: + self.assertIn(expected_key, batch) + + self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16])) + self.assertEqual( + batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] + ) + self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16])) + self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape) + self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]) + def test_data_collator_for_token_classification(self): tokenizer = BertTokenizer(self.vocab_file) features = [ @@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase): data_collator = DataCollatorWithFlattening(return_tensors="np") batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + "seq_idx", + ]: + self.assertNotIn(unexpected_key, batch) + self.assertIn("position_ids", batch) + self.assertEqual(batch["input_ids"].shape, (1, 16)) self.assertEqual( batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] ) - self.assertNotIn("attention_mask", batch) - self.assertIn("position_ids", batch) self.assertEqual(batch["position_ids"].shape, (1, 16)) self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + def test_data_collator_with_flattening_flash_attn_kwargs(self): + features = [ + {"input_ids": [10, 11, 12]}, + {"input_ids": [20, 21, 22, 23, 24, 25]}, + {"input_ids": [30, 31, 32, 33, 34, 35, 36]}, + ] + + data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True) + batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "seq_idx", + ]: + self.assertNotIn(unexpected_key, batch) + for expected_key in [ + "position_ids", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + ]: + self.assertIn(expected_key, batch) + + self.assertEqual(batch["input_ids"].shape, (1, 16)) + self.assertEqual( + batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] + ) + self.assertEqual(batch["position_ids"].shape, (1, 16)) + self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + + self.assertEqual(batch["cu_seq_lens_k"].shape, (4,)) + self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16]) + self.assertEqual(batch["cu_seq_lens_q"].shape, (4,)) + self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16]) + # The flash attn max_length_{k,q} are simple python ints + self.assertEqual(batch["max_length_k"], 7) + self.assertEqual(batch["max_length_q"], 7) + + def test_data_collator_with_flattening_seq_idx(self): + features = [ + {"input_ids": [10, 11, 12]}, + {"input_ids": [20, 21, 22, 23, 24, 25]}, + {"input_ids": [30, 31, 32, 33, 34, 35, 36]}, + ] + + data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True) + batch = data_collator(features) + + for unexpected_key in [ + "attention_mask", + "cu_seq_lens_k", + "cu_seq_lens_q", + "max_length_k", + "max_length_q", + ]: + self.assertNotIn(unexpected_key, batch) + for expected_key in [ + "position_ids", + "seq_idx", + ]: + self.assertIn(expected_key, batch) + + self.assertEqual(batch["input_ids"].shape, (1, 16)) + self.assertEqual( + batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36] + ) + self.assertEqual(batch["position_ids"].shape, (1, 16)) + self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]) + self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape) + self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]) + def test_data_collator_for_token_classification(self): tokenizer = BertTokenizer(self.vocab_file) features = [