From 1ed19360b1400bd849164e0b9be940e8342af6b1 Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Fri, 23 May 2025 18:16:43 +0200 Subject: [PATCH] [`FlexAttention`] Reenable flex for encoder-decoder and make the test more robust (#38321) * reenable most flex attention test cases * style * trigger * trigger --- src/transformers/models/bart/modeling_bart.py | 3 +-- src/transformers/models/biogpt/modeling_biogpt.py | 3 +-- src/transformers/models/biogpt/modular_biogpt.py | 3 +-- .../models/blenderbot/modeling_blenderbot.py | 3 +-- .../blenderbot_small/modeling_blenderbot_small.py | 3 +-- .../models/data2vec/modeling_data2vec_audio.py | 3 +-- .../models/data2vec/modular_data2vec_audio.py | 3 +-- src/transformers/models/hubert/modeling_hubert.py | 3 +-- src/transformers/models/hubert/modular_hubert.py | 3 +-- .../models/m2m_100/modeling_m2m_100.py | 3 +-- src/transformers/models/marian/modeling_marian.py | 3 +-- src/transformers/models/mbart/modeling_mbart.py | 3 +-- .../models/musicgen/modeling_musicgen.py | 6 ++---- .../musicgen_melody/modeling_musicgen_melody.py | 6 ++---- .../models/pegasus/modeling_pegasus.py | 3 +-- .../models/pegasus_x/modeling_pegasus_x.py | 3 +-- src/transformers/models/plbart/modeling_plbart.py | 3 +-- src/transformers/models/plbart/modular_plbart.py | 3 +-- .../speech_to_text/modeling_speech_to_text.py | 1 - .../modeling_time_series_transformer.py | 1 - .../models/unispeech/modeling_unispeech.py | 3 +-- .../models/unispeech/modular_unispeech.py | 3 +-- .../unispeech_sat/modeling_unispeech_sat.py | 3 +-- .../models/unispeech_sat/modular_unispeech_sat.py | 3 +-- .../models/wav2vec2/modeling_wav2vec2.py | 3 +-- tests/test_modeling_common.py | 15 ++++++++++++--- 26 files changed, 37 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 60d9cdba2a..2442baa243 100755 --- a/src/transformers/models/bart/modeling_bart.py 
+++ b/src/transformers/models/bart/modeling_bart.py @@ -494,8 +494,7 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 0b2a0dc274..f12eeac697 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -348,8 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 4bd675be92..78d6da134b 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -175,8 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index da7282d388..4c001a3544 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -464,8 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True 
_supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 2237907aa0..49cff8f620 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -452,8 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index d9046ea6e8..eafcbff89a 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -551,8 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 73a42937bd..0b4695c1e2 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -140,8 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git 
a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index eb366963a6..115345407e 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -738,8 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index 75000c95cb..c0454452f0 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -131,8 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 55ecad4152..f348867272 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -530,8 +530,7 @@ class M2M100PreTrainedModel(PreTrainedModel): _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True # Doesn't support `compile` (dynamic control flow). 
Can be fixed but low usage model _supports_static_cache = False diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 016cb865f8..a604820f2c 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -468,8 +468,7 @@ class MarianPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index bdf352a1f6..4f3253eeb4 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -498,8 +498,7 @@ class MBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a0e21f586c..42b0567133 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -436,8 +436,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - # compilation errors occurr atm - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_factor @@ -1361,8 +1360,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = 
True _supports_sdpa = True - # compilation errors occurr atm - _supports_flex_attn = False + _supports_flex_attn = True def __init__( self, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 3312ad33cd..4e1ea39e75 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -410,8 +410,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - # compilation errors occurr atm - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_factor @@ -1292,8 +1291,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # compilation errors occurr atm - _supports_flex_attn = False + _supports_flex_attn = True def __init__( self, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 19166bd609..303ae89fd0 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -463,8 +463,7 @@ class PegasusPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index a2fcf5edd1..bf94379cca 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ 
b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -764,8 +764,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True # Flaky logits _supports_sdpa = False - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 85813b0242..695a0ed458 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -78,8 +78,7 @@ class PLBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 8a9755b11b..1394e87f56 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -64,8 +64,7 @@ class PLBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 4acee66424..aa4ea81071 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -529,7 +529,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel): # Current tests always assume certain inputs to be passed _supports_flash_attn_2 = False _supports_sdpa = 
False - # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 3bc19a75b3..dc960efbbc 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -636,7 +636,6 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): # Current tests always assume certain inputs to be passed _supports_flash_attn_2 = False _supports_sdpa = False - # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 07ee6608b7..4fdce328e9 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -847,8 +847,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 795ab85967..5a9133089a 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -151,8 +151,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git 
a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 8d9ac9c33f..50ee4c198d 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -850,8 +850,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index 9f9e7d4f3c..f86c397a04 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -161,8 +161,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index fb01234e3f..ae3510f175 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1096,8 +1096,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # Compile issues - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4ddbbcb47f..621ab67da0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ 
-4570,20 +4570,29 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" - # Flex Attention can not use dropout + # Flex Attention cannot use dropout if hasattr(config, "attention_dropout"): config.attention_dropout = 0 if hasattr(config, "attention_probs_dropout_prob"): config.attention_probs_dropout_prob = 0 + # Flex attention relies on triton for compilation + # However, triton cannot handle head dimensions smaller than 16 + # --> forcing a head dim of at least 16 + config.hidden_size *= max( + 16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1 + ) + if hasattr(config, "head_dim"): + config.head_dim = max(16, config.head_dim) + model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") # Elaborate workaround for encoder-decoder models as some do not specify their main input dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} if config.is_encoder_decoder: - dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"] - dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"] + dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device) + dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device) # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) _ = model(**dummy_inputs)