[tests] expand flex-attn test for vision models (#38434)

* expand the test for VLMs

* typo

* mark models `supports_flex` + expand test for additional kwargs

* flex attn for refactored vision models

* fix copies

* fix

* unskip

* style

* address comments
This commit is contained in:
Raushan Turganbay
2025-06-03 09:40:44 +02:00
committed by GitHub
parent de4cf5a38e
commit bf68dd9e6e
45 changed files with 429 additions and 195 deletions

View File

@@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
)
additional_model_inputs = ["pixel_values"]
fx_compatible = True
test_head_masking = False
test_pruning = False

View File

@@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_generate_with_static_cache(self):
pass
@unittest.skip("Emu3 doesn't support Flex attn yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class Emu3IntegrationTest(unittest.TestCase):

View File

@@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_initialization(self):
pass
@unittest.skip(
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
)
def test_flex_attention_with_grads(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that

View File

@@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
def test_flex_attention_with_grads(self):
pass
@require_torch
class GotOcr2IntegrationTest(unittest.TestCase):

View File

@@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
all_generative_model_classes = ()
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
additional_model_inputs = ["decoder_input_ids"]
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False

View File

@@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
all_generative_model_classes = ()
greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {}
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
additional_model_inputs = ["decoder_input_ids"]
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False

View File

@@ -103,7 +103,7 @@ class SiglipVisionModelTester:
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
@@ -274,7 +274,7 @@ class SiglipTextModelTester:
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,

View File

@@ -180,7 +180,7 @@ class Siglip2VisionModelTester:
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
@@ -363,7 +363,7 @@ class Siglip2TextModelTester:
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,

View File

@@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
if is_torch_available()
else {}
)
# Addition keys that are required for forward, used in tests where we manipulate and create new input dict from scratch
additional_model_inputs = ["bool_masked_pos"]
test_pruning = False
test_torchscript = False
test_resize_embeddings = False

View File

@@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
def test_flex_attention_with_grads(self):
pass
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@@ -3637,7 +3637,10 @@ class ModelTesterMixin:
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
for key in getattr(self, "additional_model_inputs", []):
processed_inputs[key] = inputs_dict[key]
# Some models don't have all `additional_model_inputs`, especially when we
# craft cases to test model in different settings
if key in inputs_dict:
processed_inputs[key] = inputs_dict[key]
for key, value in processed_inputs.items():
if torch.is_floating_point(value):
@@ -4012,19 +4015,21 @@ class ModelTesterMixin:
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
sub_models_supporting_fa2 = [
(module._supports_flash_attn_2 or module._supports_attention_backend)
module._supports_flash_attn_2
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_fa2_all_modules = (
all(sub_models_supporting_fa2)
if len(sub_models_supporting_fa2) > 0
else (model._supports_flash_attn_2 or model._supports_attention_backend)
else model._supports_flash_attn_2
)
if not supports_fa2_all_modules:
with self.assertRaises(ValueError):
model_fa2 = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
)
else:
model_fa2 = model_class.from_pretrained(
@@ -4572,33 +4577,73 @@ class ModelTesterMixin:
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes:
# TODO: raushan, fix for composite models after making VLMs support new attn API
if not model_class._supports_flex_attn or self._is_composite:
self.skipTest(reason="This model does not support flex attention")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
# Flex Attention cannot use dropout
if hasattr(config, "attention_dropout"):
config.attention_dropout = 0
if hasattr(config, "attention_probs_dropout_prob"):
config.attention_probs_dropout_prob = 0
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).to(device=torch_device)
# Flex attention relies on triton on compilation
# However, triton cannot handle hidden dimensions of less than 16
# --> forcing at least a hidden dim of 16
config.hidden_size *= max(
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
# If not all sub-models support flex, skip the test
sub_models_supporting_flex = [
module._supports_flex_attn
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
)
if hasattr(config, "head_dim"):
config.head_dim = max(16, config.head_dim)
if not supports_flex_all_modules:
self.skipTest(reason="This model's submodels does not support flex attention")
def update_config_for_flex(config):
# Flex Attention cannot use dropout
if hasattr(config, "attention_dropout"):
config.attention_dropout = 0
if hasattr(config, "attention_probs_dropout_prob"):
config.attention_probs_dropout_prob = 0
# Flex attention relies on triton on compilation
# However, triton cannot handle hidden dimensions of less than 16
# --> forcing at least a hidden dim of 16
# Update the head dim and try to update hidden size as well if present in config
# NOTE: some models may have none if the values in sub-config, thus we check for `Noneness`
head_dim = None
if hasattr(config, "head_dim") and config.head_dim is not None:
head_dim = config.head_dim
config.head_dim = max(16, config.head_dim)
if (
getattr(config, "hidden_size", None) is not None
and getattr(config, "num_attention_heads", None) is not None
):
head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
config.hidden_size *= max(16 // head_dim, 1)
if (
getattr(config, "decoder_hidden_size", None) is not None
and getattr(config, "decoder_num_attention_heads", None) is not None
):
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
# Set default attention to flex and update config values
update_config_for_flex(config)
for key in config.sub_configs:
sub_config = getattr(config, key)
update_config_for_flex(sub_config)
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device)
self.assertTrue(model.config._attn_implementation == "flex_attention")
# Elaborate workaround for encoder-decoder models as some do not specify their main input
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
if config.is_encoder_decoder:
for key in getattr(self, "additional_model_inputs", []):
# Some models don't have all `additional_model_inputs`, especially when we
# craft cases to test model in different settings
if key in inputs_dict:
dummy_inputs[key] = inputs_dict[key].to(torch_device)
if config.get_text_config(decoder=True).is_encoder_decoder:
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)