Fix the check in flex test (#39548)

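* fix the check: skip the flex attention test unless the model itself and every `PreTrainedModel` sub-module report `_supports_flex_attn`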

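* fix flags: declare `_supports_sdpa`, `_supports_flash_attn` and `_supports_flex_attn` on ColPali and ColQwen2, and set all three to `False` on GotOcr2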

* flags
This commit is contained in:
Cyril Vallez
2025-07-21 13:29:44 +02:00
committed by GitHub
parent 78fb2d2760
commit 3a152e3a5c
6 changed files with 18 additions and 18 deletions

View File

@@ -33,6 +33,9 @@ class ColPaliPreTrainedModel(PreTrainedModel):
config: ColPaliConfig config: ColPaliConfig
base_model_prefix = "model" base_model_prefix = "model"
_no_split_modules = [] _no_split_modules = []
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
def _init_weights(self, module): def _init_weights(self, module):
std = ( std = (

View File

@@ -41,8 +41,9 @@ class ColQwen2PreTrainedModel(PreTrainedModel):
config: ColQwen2Config config: ColQwen2Config
base_model_prefix = "model" base_model_prefix = "model"
_no_split_modules = [] _no_split_modules = []
_supports_flash_attn = True
_supports_sdpa = True _supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
def _init_weights(self, module): def _init_weights(self, module):
std = ( std = (

View File

@@ -226,8 +226,7 @@ class ColQwen2Processor(ColPaliProcessor):
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
_supports_flash_attn = True pass
_supports_sdpa = True
@dataclass @dataclass

View File

@@ -280,12 +280,11 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
base_model_prefix = "" base_model_prefix = ""
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn = False
_supports_flash_attn = True _supports_sdpa = False
_supports_sdpa = True
_supports_static_cache = True _supports_static_cache = True
_supports_flex_attn = True _supports_flex_attn = False
_supports_attention_backend = True _supports_attention_backend = True
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -286,6 +286,10 @@ class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
class GotOcr2PreTrainedModel(LlavaPreTrainedModel): class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
_supports_flash_attn = False
_supports_sdpa = False
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

View File

@@ -4599,16 +4599,10 @@ class ModelTesterMixin:
model = model_class(config).to(device=torch_device) model = model_class(config).to(device=torch_device)
# If not all sub-models support flex, skip the test # If not all sub-models support flex, skip the test
sub_models_supporting_flex = [ if not all(
module._supports_flex_attn submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
for name, module in model.named_modules() ):
if isinstance(module, PreTrainedModel) and name != "" self.skipTest(reason="At least some parts of this model do not support flex attention")
]
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
)
if not supports_flex_all_modules:
self.skipTest(reason="This model's submodels does not support flex attention")
def update_config_for_flex(config): def update_config_for_flex(config):
# Flex Attention cannot use dropout # Flex Attention cannot use dropout
@@ -4664,8 +4658,8 @@ class ModelTesterMixin:
sub_config = getattr(config, key) sub_config = getattr(config, key)
update_config_for_flex(sub_config) update_config_for_flex(sub_config)
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device) model = model_class(config).to(device=torch_device)
model.set_attn_implementation("flex_attention")
self.assertTrue(model.config._attn_implementation == "flex_attention") self.assertTrue(model.config._attn_implementation == "flex_attention")
# Elaborate workaround for encoder-decoder models as some do not specify their main input # Elaborate workaround for encoder-decoder models as some do not specify their main input