[tests] expand flex-attn test for vision models (#38434)
* expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments
This commit is contained in:
committed by
GitHub
parent
de4cf5a38e
commit
bf68dd9e6e
@@ -3637,7 +3637,10 @@ class ModelTesterMixin:
|
||||
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
|
||||
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
|
||||
for key, value in processed_inputs.items():
|
||||
if torch.is_floating_point(value):
|
||||
@@ -4012,19 +4015,21 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
(module._supports_flash_attn_2 or module._supports_attention_backend)
|
||||
module._supports_flash_attn_2
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else (model._supports_flash_attn_2 or model._supports_attention_backend)
|
||||
else model._supports_flash_attn_2
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
else:
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
@@ -4572,33 +4577,73 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
def test_flex_attention_with_grads(self):
|
||||
for model_class in self.all_model_classes:
|
||||
# TODO: raushan, fix for composite models after making VLMs support new attn API
|
||||
if not model_class._supports_flex_attn or self._is_composite:
|
||||
self.skipTest(reason="This model does not support flex attention")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flex_attention"
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).to(device=torch_device)
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
config.hidden_size *= max(
|
||||
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||
# If not all sub-models support flex, skip the test
|
||||
sub_models_supporting_flex = [
|
||||
module._supports_flex_attn
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
|
||||
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
|
||||
)
|
||||
if hasattr(config, "head_dim"):
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
if not supports_flex_all_modules:
|
||||
self.skipTest(reason="This model's submodels does not support flex attention")
|
||||
|
||||
def update_config_for_flex(config):
    """Adjust ``config`` in place so that it is usable with flex attention.

    Two things are changed, each only when the relevant attributes exist and
    are not ``None``:

    * Attention dropout (``attention_dropout`` and
      ``attention_probs_dropout_prob``) is set to 0, because flex attention
      cannot use dropout.
    * ``head_dim`` is raised to at least 16, and ``hidden_size`` /
      ``decoder_hidden_size`` are multiplied by the smallest integer factor
      that brings the per-head dimension to at least 16. Flex attention
      relies on triton for compilation, and triton cannot handle hidden
      dimensions of less than 16.

    Args:
        config: A model config or sub-config object. It is mutated in place.

    Returns:
        None.
    """
    # Flex Attention cannot use dropout
    if hasattr(config, "attention_dropout"):
        config.attention_dropout = 0
    if hasattr(config, "attention_probs_dropout_prob"):
        config.attention_probs_dropout_prob = 0

    # Flex attention relies on triton on compilation
    # However, triton cannot handle hidden dimensions of less than 16
    # --> forcing at least a hidden dim of 16

    # Update the head dim and try to update hidden size as well if present in config
    # NOTE: some models may have none of the values in sub-config, thus we check for `Noneness`
    head_dim = None
    if hasattr(config, "head_dim") and config.head_dim is not None:
        # Remember the original value: it drives the hidden-size scaling below.
        head_dim = config.head_dim
        config.head_dim = max(16, config.head_dim)

    if (
        getattr(config, "hidden_size", None) is not None
        and getattr(config, "num_attention_heads", None) is not None
    ):
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        # Ceiling division (`-(-a // b)`): a floor-based factor would leave the
        # per-head dim below 16 whenever 16 is not a multiple of `head_dim`
        # (e.g. head_dim=12 would give a factor of 1 and stay at 12).
        config.hidden_size *= max(-(-16 // head_dim), 1)

    if (
        getattr(config, "decoder_hidden_size", None) is not None
        and getattr(config, "decoder_num_attention_heads", None) is not None
    ):
        decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
        # Same ceiling-based scaling for the decoder side.
        config.decoder_hidden_size *= max(-(-16 // decoder_head_dim), 1)
|
||||
|
||||
# Set default attention to flex and update config values
|
||||
update_config_for_flex(config)
|
||||
for key in config.sub_configs:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_for_flex(sub_config)
|
||||
|
||||
config._attn_implementation = "flex_attention"
|
||||
model = model_class(config).to(device=torch_device)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||
if config.is_encoder_decoder:
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
dummy_inputs[key] = inputs_dict[key].to(torch_device)
|
||||
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user