From 3a152e3a5c633ece00b036cc23fe324ad0d2c196 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 21 Jul 2025 13:29:44 +0200 Subject: [PATCH] Fix the check in flex test (#39548) * fix the check * fix flags * flags --- .../models/colpali/modeling_colpali.py | 3 +++ .../models/colqwen2/modeling_colqwen2.py | 3 ++- .../models/colqwen2/modular_colqwen2.py | 3 +-- .../models/got_ocr2/modeling_got_ocr2.py | 7 +++---- .../models/got_ocr2/modular_got_ocr2.py | 4 ++++ tests/test_modeling_common.py | 16 +++++----------- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 63ce8975e8..ad53d6a985 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -33,6 +33,9 @@ class ColPaliPreTrainedModel(PreTrainedModel): config: ColPaliConfig base_model_prefix = "model" _no_split_modules = [] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True def _init_weights(self, module): std = ( diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 6cbdaab123..684804ee37 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -41,8 +41,9 @@ class ColQwen2PreTrainedModel(PreTrainedModel): config: ColQwen2Config base_model_prefix = "model" _no_split_modules = [] - _supports_flash_attn = True _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True def _init_weights(self, module): std = ( diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 3d2a57f1c3..0fcdc09b7b 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -226,8 +226,7 @@ class ColQwen2Processor(ColPaliProcessor): class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): - _supports_flash_attn = True - _supports_sdpa = True + pass @dataclass diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 4ff79a53de..dc6f81945b 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -280,12 +280,11 @@ class GotOcr2PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - - _supports_flash_attn = True - _supports_sdpa = True + _supports_flash_attn = False + _supports_sdpa = False _supports_static_cache = True - _supports_flex_attn = True + _supports_flex_attn = False _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 1ae6880753..06b9fca298 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -286,6 +286,10 @@ class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast): class GotOcr2PreTrainedModel(LlavaPreTrainedModel): + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 30138b0850..26f6a032d6 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4599,16 +4599,10 @@ class ModelTesterMixin: model = model_class(config).to(device=torch_device) # If not all sub-models support flex, skip the test - sub_models_supporting_flex = [ - module._supports_flex_attn - for name, module in model.named_modules() - if isinstance(module, PreTrainedModel) and name != "" - ] - supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or ( - model._supports_flex_attn and len(sub_models_supporting_flex) == 0 - ) - if not supports_flex_all_modules: - self.skipTest(reason="This model's submodels does not support flex attention") + if not all( + submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(reason="At least some parts of this model do not support flex attention") def update_config_for_flex(config): # Flex Attention cannot use dropout @@ -4664,8 +4658,8 @@ class ModelTesterMixin: sub_config = getattr(config, key) update_config_for_flex(sub_config) - config._attn_implementation = "flex_attention" model = model_class(config).to(device=torch_device) + model.set_attn_implementation("flex_attention") self.assertTrue(model.config._attn_implementation == "flex_attention") # Elaborate workaround for encoder-decoder models as some do not specify their main input