Fix the check in flex test (#39548)
* fix the check * fix flags * flags
This commit is contained in:
@@ -4599,16 +4599,10 @@ class ModelTesterMixin:
|
||||
model = model_class(config).to(device=torch_device)
|
||||
|
||||
# If not all sub-models support flex, skip the test
|
||||
sub_models_supporting_flex = [
|
||||
module._supports_flex_attn
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
|
||||
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
|
||||
)
|
||||
if not supports_flex_all_modules:
|
||||
self.skipTest(reason="This model's submodels does not support flex attention")
|
||||
if not all(
|
||||
submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="At least some parts of this model do not support flex attention")
|
||||
|
||||
def update_config_for_flex(config):
|
||||
# Flex Attention cannot use dropout
|
||||
@@ -4664,8 +4658,8 @@ class ModelTesterMixin:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_for_flex(sub_config)
|
||||
|
||||
config._attn_implementation = "flex_attention"
|
||||
model = model_class(config).to(device=torch_device)
|
||||
model.set_attn_implementation("flex_attention")
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
|
||||
Reference in New Issue
Block a user