Fix AriaForConditionalGeneration flex attn test (#36604)
AriaForConditionalGeneration depends on idefics3 vision transformer which does not support flex attn
This commit is contained in:
@@ -1380,6 +1380,7 @@ ARIA_START_DOCSTRING = r"""
|
||||
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
config_class = AriaConfig
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_flex_attn = False
|
||||
_supports_sdpa = False
|
||||
_tied_weights_keys = ["language_model.lm_head.weight"]
|
||||
|
||||
|
||||
@@ -1348,6 +1348,7 @@ ARIA_START_DOCSTRING = r"""
|
||||
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
config_class = AriaConfig
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_flex_attn = False
|
||||
_supports_sdpa = False
|
||||
_tied_weights_keys = ["language_model.lm_head.weight"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user