[attention] fix test for packed padfree masking (#39582)

* fix most tests

* skip a few more tests

* address comments

* fix chameleon tests

* forgot to uncomment

* qwen has its own tests with images, rename it as well
This commit is contained in:
Raushan Turganbay
2025-07-25 09:44:52 +02:00
committed by GitHub
parent 565c035a2e
commit c392d47c9b
17 changed files with 153 additions and 250 deletions

View File

@@ -4125,9 +4125,13 @@ class ModelTesterMixin:
assert not loss.isnan().any()
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
"""
Tests that the given attention implementation can work with packed sequences and infers the mask
from position ids. This test requires the model to use new attention mask API which handles packing.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
@@ -4142,17 +4146,27 @@ class ModelTesterMixin:
if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
# can't infer if new attn mask API is supported by assume that only model with attention backend support it
if not model_class._supports_attention_backend:
self.skipTest(f"{model_class.__name__} does not support new attention mask API")
if model_class._is_stateful: # non-transformer models most probably have no packing support
self.skipTest(f"{model_class.__name__} doesn't support packing!")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.is_encoder_decoder:
self.skipTest("Model is an encoder-decoder")
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
self.skipTest("Model dummy inputs should contain text input ids")
# make sure that all models have enough positions for generation
dummy_input_ids = inputs_dict["input_ids"]
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
@@ -4164,11 +4178,14 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
# Drop all keys except for the minimal set. Hard to manipulate with multimodals/head_mask/etc
inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]}
# Ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
@@ -4183,8 +4200,7 @@ class ModelTesterMixin:
if fa_kwargs:
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
{"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask)
]
# add position_ids + fa_kwargs
@@ -4194,55 +4210,48 @@ class ModelTesterMixin:
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
}
else:
# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
# create packed position_ids
position_ids = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)
padfree_inputs_dict = {
"input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0),
"position_ids": position_ids,
}
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
# We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padded = res_padded.logits[dummy_attention_mask.bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
# FIXME @raushan
@slow
def test_eager_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
@slow
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_2", fa_kwargs=True
)
@@ -4251,14 +4260,14 @@ class ModelTesterMixin:
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_3", fa_kwargs=True
)