[VLMs] support attention backends (#37576)

* update models

* why rename

* return attn weights when sdpa

* fixes

* fix attn implementation composite

* fix moshi

* add message

* add typings

* use explicitly all flags for each attn type

* fix some tests

* import what is needed

* kosmos on main has ew attention already, yay

* new models in main, run fixup

* won't fix kosmos yet

* fix-copies

* clean up after rebasing

* fix tests

* style

* dont cast attns to fp32

* did we update ruff? oke, let's just do what it asks

* fix pixtral after rebase
This commit is contained in:
Raushan Turganbay
2025-05-08 18:18:54 +02:00
committed by GitHub
parent e296c63cd4
commit d23aae2b8c
47 changed files with 1318 additions and 1555 deletions

View File

@@ -461,6 +461,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
@require_torch
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -526,15 +527,11 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -545,20 +542,13 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -869,6 +859,7 @@ class Blip2ModelTester:
@require_torch
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "decoder_input_ids"]
# Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
all_generative_model_classes = ()
pipeline_model_mapping = (
@@ -967,15 +958,11 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "eager")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -986,20 +973,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1485,6 +1465,7 @@ class Blip2TextRetrievalModelTester:
@require_torch
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False

View File

@@ -475,6 +475,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
else ()
)
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
additional_model_inputs = ["qformer_input_ids", "input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -687,15 +688,11 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -706,20 +703,13 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -492,6 +492,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
all_model_classes = (
(InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else ()
)
additional_model_inputs = ["qformer_input_ids", "input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -702,15 +703,11 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@@ -721,20 +718,13 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_video():

View File

@@ -30,6 +30,7 @@ from transformers.testing_utils import (
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@@ -42,6 +43,7 @@ from transformers.utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -259,6 +261,7 @@ class Kosmos2ModelTester:
@require_torch
class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "image_embeds_position_mask"]
pipeline_model_mapping = (
{
"feature-extraction": Kosmos2Model,
@@ -462,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds(self):
pass
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask

View File

@@ -219,9 +219,10 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
else {}
)
is_encoder_decoder = False
fx_compatible = True
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
test_pruning = False
test_missing_keys = False
test_head_masking = False # new attn API doesn't support head mask
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(

View File

@@ -109,6 +109,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
additional_model_inputs = ["image_sizes"]
test_pruning = False
test_head_masking = False
test_torchscript = False

View File

@@ -3765,6 +3765,10 @@ class ModelTesterMixin:
key = "decoder_hidden_states"
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
key = "logits"
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
outputs_eager = outputs_eager["language_model_outputs"]
outputs_sdpa = outputs_sdpa["language_model_outputs"]
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
else:
key = "hidden_states"
@@ -4002,14 +4006,14 @@ class ModelTesterMixin:
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
sub_models_supporting_fa2 = [
module._supports_flash_attn_2
(module._supports_flash_attn_2 or module._supports_attention_backend)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_fa2_all_modules = (
all(sub_models_supporting_fa2)
if len(sub_models_supporting_fa2) > 0
else model._supports_flash_attn_2
else (model._supports_flash_attn_2 or model._supports_attention_backend)
)
if not supports_fa2_all_modules:
with self.assertRaises(ValueError):