Add padding-free to bamba (#35861)

* add seq_idx and fa kwargs

* update tests

* docs and grad ckpt support

* fmt

* better names

* test_raise_missing_padding_free_kwarg_errs

* + seq_idx in doc strings

* padding free training docs

* add link to pr plots

* raise err on attn_mask with padding free

* rm raising missing padding free err test

* BambaFlashAttentionKwargs

* run modular util for modular_granitemoehybrid.py
This commit is contained in:
Garrett Goon
2025-05-20 11:13:59 -04:00
committed by GitHub
parent 2a79471318
commit 390f153469
5 changed files with 233 additions and 25 deletions

View File

@@ -14,16 +14,25 @@
"""Testing suite for the PyTorch Bamba model."""
import inspect
import tempfile
import unittest
import pytest
from pytest import mark
from transformers import AutoTokenizer, BambaConfig, is_torch_available
from transformers import (
AutoTokenizer,
BambaConfig,
DataCollatorWithFlattening,
is_torch_available,
)
from transformers.testing_utils import (
Expectations,
require_deterministic_for_xpu,
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)
@@ -489,6 +498,92 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
pass
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
)
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]
# add position_ids + fa_kwargs + seq_idx
data_collator = DataCollatorWithFlattening(
return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
)
batch = data_collator(features)
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
res_padded = model(**inputs_dict)
res_padfree = model(**batch_cuda)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@slow
@require_torch