[qwen2 vl] fix packing with all attentions (#39447)
* fix qwen2 vl packing in FA2 * why? delete! * qwen2-5-vl seems to work now * update * fix tests * start by adapting FA2 tests * add similar tests for sdpa/eager * address comments * why is this even in conditional model and not base model?
This commit is contained in:
committed by
GitHub
parent
e42681b48b
commit
344012b3a6
@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn",
|
||||
"flash_attention_3": "_supports_flash_attn",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not (
|
||||
model_class._supports_flash_attn_2
|
||||
if attn_implementation == "flash_attention_2"
|
||||
else model_class._supports_flash_attn_3
|
||||
):
|
||||
if not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**padfree_inputs_dict)
|
||||
# We need to do simple forward without cache in order to trigger packed SDPA/FLEX/EAGER path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||
# FIXME @raushan
|
||||
@slow
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
|
||||
@slow
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user