[VLMs] support attention backends (#37576)

* update models

* why rename

* return attn weights when sdpa

* fixes

* fix attn implementation composite

* fix moshi

* add message

* add typings

* use explicitly all flags for each attn type

* fix some tests

* import what is needed

* kosmos on main has new attention already, yay

* new models in main, run fixup

* won't fix kosmos yet

* fix-copies

* clean up after rebasing

* fix tests

* style

* don't cast attns to fp32

* did we update ruff? okay, let's just do what it asks

* fix pixtral after rebase
This commit is contained in:
Raushan Turganbay
2025-05-08 18:18:54 +02:00
committed by GitHub
parent e296c63cd4
commit d23aae2b8c
47 changed files with 1318 additions and 1555 deletions

View File

@@ -461,6 +461,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
@require_torch
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -526,15 +527,11 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -545,20 +542,13 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -869,6 +859,7 @@ class Blip2ModelTester:
@require_torch
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "decoder_input_ids"]
# Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
all_generative_model_classes = ()
pipeline_model_mapping = (
@@ -967,15 +958,11 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "eager")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -986,20 +973,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1485,6 +1465,7 @@ class Blip2TextRetrievalModelTester:
@require_torch
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False