Fix the check in flex test (#39548)

* fix the check

* fix flags

* flags
This commit is contained in:
Cyril Vallez
2025-07-21 13:29:44 +02:00
committed by GitHub
parent 78fb2d2760
commit 3a152e3a5c
6 changed files with 18 additions and 18 deletions

View File

@@ -4599,16 +4599,10 @@ class ModelTesterMixin:
model = model_class(config).to(device=torch_device)
# If not all sub-models support flex, skip the test
sub_models_supporting_flex = [
module._supports_flex_attn
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
)
if not supports_flex_all_modules:
self.skipTest(reason="This model's submodels does not support flex attention")
if not all(
submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
):
self.skipTest(reason="At least some parts of this model do not support flex attention")
def update_config_for_flex(config):
# Flex Attention cannot use dropout
@@ -4664,8 +4658,8 @@ class ModelTesterMixin:
sub_config = getattr(config, key)
update_config_for_flex(sub_config)
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device)
model.set_attn_implementation("flex_attention")
self.assertTrue(model.config._attn_implementation == "flex_attention")
# Elaborate workaround for encoder-decoder models as some do not specify their main input