From b1a51ea4642f120ed693a24671215aa3f9929dd5 Mon Sep 17 00:00:00 2001
From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Tue, 11 Mar 2025 11:05:49 +0100
Subject: [PATCH] Fix AriaForConditionalGeneration flex attn test (#36604)

AriaForConditionalGeneration depends on the idefics3 vision transformer, which does not support flex attn
---
 src/transformers/models/aria/modeling_aria.py | 1 +
 src/transformers/models/aria/modular_aria.py  | 1 +
 2 files changed, 2 insertions(+)

diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 05d5ba8997..252ddb694f 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -1380,6 +1380,7 @@ ARIA_START_DOCSTRING = r"""
 class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     config_class = AriaConfig
     _supports_flash_attn_2 = False
+    _supports_flex_attn = False
     _supports_sdpa = False
     _tied_weights_keys = ["language_model.lm_head.weight"]
 
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index c62c074218..bf9d864a4c 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -1348,6 +1348,7 @@ ARIA_START_DOCSTRING = r"""
 class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     config_class = AriaConfig
     _supports_flash_attn_2 = False
+    _supports_flex_attn = False
     _supports_sdpa = False
     _tied_weights_keys = ["language_model.lm_head.weight"]
 