[tests] expand flex-attn test for vision models (#38434)
* expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments
This commit is contained in:
committed by
GitHub
parent
de4cf5a38e
commit
bf68dd9e6e
@@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
|
||||
)
|
||||
additional_model_inputs = ["pixel_values"]
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
||||
@@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Emu3 doesn't support Flex attn yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Emu3IntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
||||
)
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
|
||||
@@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class GotOcr2IntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
|
||||
# Additional keys that are required for forward. MusicGen isn't encoder-decoder in config, so we have to pass decoder ids as additional inputs
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {}
|
||||
# Additional keys that are required for forward. MusicGen isn't encoder-decoder in config, so we have to pass decoder ids as additional inputs
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -103,7 +103,7 @@ class SiglipVisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -274,7 +274,7 @@ class SiglipTextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
||||
@@ -180,7 +180,7 @@ class Siglip2VisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -363,7 +363,7 @@ class Siglip2TextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
||||
@@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
# Additional keys that are required for forward, used in tests where we manipulate and create a new input dict from scratch
|
||||
additional_model_inputs = ["bool_masked_pos"]
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -3637,7 +3637,10 @@ class ModelTesterMixin:
|
||||
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
|
||||
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
|
||||
for key, value in processed_inputs.items():
|
||||
if torch.is_floating_point(value):
|
||||
@@ -4012,19 +4015,21 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
(module._supports_flash_attn_2 or module._supports_attention_backend)
|
||||
module._supports_flash_attn_2
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else (model._supports_flash_attn_2 or model._supports_attention_backend)
|
||||
else model._supports_flash_attn_2
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
else:
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
@@ -4572,33 +4577,73 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
def test_flex_attention_with_grads(self):
|
||||
for model_class in self.all_model_classes:
|
||||
# TODO: raushan, fix for composite models after making VLMs support new attn API
|
||||
if not model_class._supports_flex_attn or self._is_composite:
|
||||
self.skipTest(reason="This model does not support flex attention")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flex_attention"
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).to(device=torch_device)
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
config.hidden_size *= max(
|
||||
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||
# If not all sub-models support flex, skip the test
|
||||
sub_models_supporting_flex = [
|
||||
module._supports_flex_attn
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
|
||||
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
|
||||
)
|
||||
if hasattr(config, "head_dim"):
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
if not supports_flex_all_modules:
|
||||
self.skipTest(reason="This model's submodels does not support flex attention")
|
||||
|
||||
def update_config_for_flex(config):
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
|
||||
# Update the head dim and try to update hidden size as well if present in config
|
||||
# NOTE: some models may have `None` for these values in a sub-config, thus we check for `None`-ness
|
||||
head_dim = None
|
||||
if hasattr(config, "head_dim") and config.head_dim is not None:
|
||||
head_dim = config.head_dim
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
|
||||
if (
|
||||
getattr(config, "hidden_size", None) is not None
|
||||
and getattr(config, "num_attention_heads", None) is not None
|
||||
):
|
||||
head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
|
||||
config.hidden_size *= max(16 // head_dim, 1)
|
||||
|
||||
if (
|
||||
getattr(config, "decoder_hidden_size", None) is not None
|
||||
and getattr(config, "decoder_num_attention_heads", None) is not None
|
||||
):
|
||||
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
|
||||
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
|
||||
|
||||
# Set default attention to flex and update config values
|
||||
update_config_for_flex(config)
|
||||
for key in config.sub_configs:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_for_flex(sub_config)
|
||||
|
||||
config._attn_implementation = "flex_attention"
|
||||
model = model_class(config).to(device=torch_device)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||
if config.is_encoder_decoder:
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
dummy_inputs[key] = inputs_dict[key].to(torch_device)
|
||||
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user