[tests] expand flex-attn test for vision models (#38434)
* expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments
This commit is contained in:
committed by
GitHub
parent
de4cf5a38e
commit
bf68dd9e6e
@@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
|
||||
)
|
||||
additional_model_inputs = ["pixel_values"]
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
||||
@@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Emu3 doesn't support Flex attn yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Emu3IntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
||||
)
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
|
||||
@@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class GotOcr2IntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
|
||||
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {}
|
||||
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -103,7 +103,7 @@ class SiglipVisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -274,7 +274,7 @@ class SiglipTextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
||||
@@ -180,7 +180,7 @@ class Siglip2VisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -363,7 +363,7 @@ class Siglip2TextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
||||
@@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
# Addition keys that are required for forward, used in tests where we manipulate and create new input dict from scratch
|
||||
additional_model_inputs = ["bool_masked_pos"]
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user