[qwen2 vl] fix packing with all attentions (#39447)

* fix qwen2 vl packing in FA2

* why? delete!

* qwen2-5-vl seems to work now

* update

* fix tests

* start by adapting FA2 tests

* add similar tests for sdpa/eager

* address comments

* why is this even in conditional model and not base model?
This commit is contained in:
Raushan Turganbay
2025-07-21 12:19:15 +02:00
committed by GitHub
parent e42681b48b
commit 344012b3a6
9 changed files with 478 additions and 100 deletions

View File

@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn",
"flash_attention_3": "_supports_flash_attn",
}
for model_class in self.all_generative_model_classes:
if not (
model_class._supports_flash_attn_2
if attn_implementation == "flash_attention_2"
else model_class._supports_flash_attn_3
):
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
.to(torch_device)
)
res_padded = model(**inputs_dict)
res_padfree = model(**padfree_inputs_dict)
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
# FIXME @raushan
@slow
def test_eager_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
@slow
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test