Fix the check in flex test (#39548)
* fix the check
* fix flags
* flags
This commit is contained in:
@@ -33,6 +33,9 @@ class ColPaliPreTrainedModel(PreTrainedModel):
|
|||||||
config: ColPaliConfig
|
config: ColPaliConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_no_split_modules = []
|
_no_split_modules = []
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_flash_attn = True
|
||||||
|
_supports_flex_attn = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = (
|
std = (
|
||||||
|
|||||||
@@ -41,8 +41,9 @@ class ColQwen2PreTrainedModel(PreTrainedModel):
|
|||||||
config: ColQwen2Config
|
config: ColQwen2Config
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_no_split_modules = []
|
_no_split_modules = []
|
||||||
_supports_flash_attn = True
|
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_flash_attn = True
|
||||||
|
_supports_flex_attn = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = (
|
std = (
|
||||||
|
|||||||
@@ -226,8 +226,7 @@ class ColQwen2Processor(ColPaliProcessor):
|
|||||||
|
|
||||||
|
|
||||||
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
|
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
|
||||||
_supports_flash_attn = True
|
pass
|
||||||
_supports_sdpa = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -280,12 +280,11 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_flash_attn = False
|
||||||
_supports_flash_attn = True
|
_supports_sdpa = False
|
||||||
_supports_sdpa = True
|
|
||||||
|
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
_supports_flex_attn = True
|
_supports_flex_attn = False
|
||||||
_supports_attention_backend = True
|
_supports_attention_backend = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -286,6 +286,10 @@ class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
|
|||||||
|
|
||||||
|
|
||||||
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
|
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
|
||||||
|
_supports_flash_attn = False
|
||||||
|
_supports_sdpa = False
|
||||||
|
_supports_flex_attn = False
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||||
|
|
||||||
|
|||||||
@@ -4599,16 +4599,10 @@ class ModelTesterMixin:
|
|||||||
model = model_class(config).to(device=torch_device)
|
model = model_class(config).to(device=torch_device)
|
||||||
|
|
||||||
# If not all sub-models support flex, skip the test
|
# If not all sub-models support flex, skip the test
|
||||||
sub_models_supporting_flex = [
|
if not all(
|
||||||
module._supports_flex_attn
|
submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
|
||||||
for name, module in model.named_modules()
|
):
|
||||||
if isinstance(module, PreTrainedModel) and name != ""
|
self.skipTest(reason="At least some parts of this model do not support flex attention")
|
||||||
]
|
|
||||||
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
|
|
||||||
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
|
|
||||||
)
|
|
||||||
if not supports_flex_all_modules:
|
|
||||||
self.skipTest(reason="This model's submodels does not support flex attention")
|
|
||||||
|
|
||||||
def update_config_for_flex(config):
|
def update_config_for_flex(config):
|
||||||
# Flex Attention cannot use dropout
|
# Flex Attention cannot use dropout
|
||||||
@@ -4664,8 +4658,8 @@ class ModelTesterMixin:
|
|||||||
sub_config = getattr(config, key)
|
sub_config = getattr(config, key)
|
||||||
update_config_for_flex(sub_config)
|
update_config_for_flex(sub_config)
|
||||||
|
|
||||||
config._attn_implementation = "flex_attention"
|
|
||||||
model = model_class(config).to(device=torch_device)
|
model = model_class(config).to(device=torch_device)
|
||||||
|
model.set_attn_implementation("flex_attention")
|
||||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||||
|
|
||||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||||
|
|||||||
Reference in New Issue
Block a user