Add padding-free to bamba (#35861)
* add seq_idx and fa kwargs * update tests * docs and grad ckpt support * fmt * better names * test_raise_missing_padding_free_kwarg_errs * + seq_idx in doc strings * padding free training docs * add link to pr plots * raise err on attn_mask with padding free * rm raising missing padding free err test * BambaFlashAttentionKwargs * run modular util for modular_granitemoehybrid.py
This commit is contained in:
@@ -14,16 +14,25 @@
|
||||
"""Testing suite for the PyTorch Bamba model."""
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BambaConfig,
|
||||
DataCollatorWithFlattening,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_deterministic_for_xpu,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -489,6 +498,92 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip(
|
||||
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||
)
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||
)
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
self.skipTest("Model does not support position_ids")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
features = [
|
||||
{"input_ids": i[a.bool()].tolist()}
|
||||
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||
]
|
||||
|
||||
# add position_ids + fa_kwargs + seq_idx
|
||||
data_collator = DataCollatorWithFlattening(
|
||||
return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
|
||||
)
|
||||
batch = data_collator(features)
|
||||
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**batch_cuda)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user