[VLMs] support attention backends (#37576)

* update models

* why rename

* return attn weights when sdpa

* fixes

* fix attn implementation composite

* fix moshi

* add message

* add typings

* use explicitly all flags for each attn type

* fix some tests

* import what is needed

* kosmos on main has ew attention already, yay

* new models in main, run fixup

* won't fix kosmos yet

* fix-copies

* clean up after rebasing

* fix tests

* style

* dont cast attns to fp32

* did we update ruff? oke, let's just do what it asks

* fix pixtral after rebase
This commit is contained in:
Raushan Turganbay
2025-05-08 18:18:54 +02:00
committed by GitHub
parent e296c63cd4
commit d23aae2b8c
47 changed files with 1318 additions and 1555 deletions

View File

@@ -30,6 +30,7 @@ from transformers.testing_utils import (
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@@ -42,6 +43,7 @@ from transformers.utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -259,6 +261,7 @@ class Kosmos2ModelTester:
@require_torch
class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "image_embeds_position_mask"]
pipeline_model_mapping = (
{
"feature-extraction": Kosmos2Model,
@@ -462,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds(self):
pass
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask