[tests] expand flex-attn test for vision models (#38434)

* expand the test for VLMs

* typo

* mark models `supports_flex` + expand test for additional kwargs

* flex attn for refactored vision models

* fix copies

* fix

* unskip

* style

* address comments
This commit is contained in:
Raushan Turganbay
2025-06-03 09:40:44 +02:00
committed by GitHub
parent de4cf5a38e
commit bf68dd9e6e
45 changed files with 429 additions and 195 deletions

View File

@@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
)
additional_model_inputs = ["pixel_values"]
fx_compatible = True
test_head_masking = False
test_pruning = False

View File

@@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_generate_with_static_cache(self):
pass
@unittest.skip("Emu3 doesn't support Flex attn yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class Emu3IntegrationTest(unittest.TestCase):

View File

@@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_initialization(self):
pass
@unittest.skip(
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
)
def test_flex_attention_with_grads(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that

View File

@@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class GotOcr2IntegrationTest(unittest.TestCase):

View File

@@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
all_generative_model_classes = ()
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
# Additional keys that are required for forward. MusicGen isn't encoder-decoder in config, so we have to pass decoder ids as additional inputs
additional_model_inputs = ["decoder_input_ids"]
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False

View File

@@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
all_generative_model_classes = ()
greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {}
# Additional keys that are required for forward. MusicGen isn't encoder-decoder in config, so we have to pass decoder ids as additional inputs
additional_model_inputs = ["decoder_input_ids"]
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False

View File

@@ -103,7 +103,7 @@ class SiglipVisionModelTester:
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
@@ -274,7 +274,7 @@ class SiglipTextModelTester:
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,

View File

@@ -180,7 +180,7 @@ class Siglip2VisionModelTester:
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
@@ -363,7 +363,7 @@ class Siglip2TextModelTester:
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,

View File

@@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
if is_torch_available()
else {}
)
# Additional keys that are required for forward, used in tests where we manipulate and create a new input dict from scratch
additional_model_inputs = ["bool_masked_pos"]
test_pruning = False
test_torchscript = False
test_resize_embeddings = False

View File

@@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
def test_flex_attention_with_grads(self):
pass
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):