Fix AriaForConditionalGeneration flex attn test (#36604)

AriaForConditionalGeneration depends on idefics3 vision transformer which does not support flex attn
This commit is contained in:
ivarflakstad
2025-03-11 11:05:49 +01:00
committed by GitHub
parent d126f35427
commit b1a51ea464
2 changed files with 2 additions and 0 deletions

View File

@@ -1380,6 +1380,7 @@ ARIA_START_DOCSTRING = r"""
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]

View File

@@ -1348,6 +1348,7 @@ ARIA_START_DOCSTRING = r"""
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]