[FlexAttention] Reenable flex for encoder-decoder and make the test more robust (#38321)

* reenable most flex attention test cases

* style

* trigger

* trigger
This commit is contained in:
Anton Vlasjuk
2025-05-23 18:16:43 +02:00
committed by GitHub
parent bb567d85a4
commit 1ed19360b1
26 changed files with 37 additions and 55 deletions

View File

@@ -494,8 +494,7 @@ class BartPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -348,8 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -175,8 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -464,8 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -452,8 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -551,8 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -140,8 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -738,8 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -131,8 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -530,8 +530,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
_supports_static_cache = False _supports_static_cache = False

View File

@@ -468,8 +468,7 @@ class MarianPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -498,8 +498,7 @@ class MBartPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -436,8 +436,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# compilation errors occurr atm _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_factor std = self.config.initializer_factor
@@ -1361,8 +1360,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# compilation errors occurr atm _supports_flex_attn = True
_supports_flex_attn = False
def __init__( def __init__(
self, self,

View File

@@ -410,8 +410,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# compilation errors occurr atm _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_factor std = self.config.initializer_factor
@@ -1292,8 +1291,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# compilation errors occurr atm _supports_flex_attn = True
_supports_flex_attn = False
def __init__( def __init__(
self, self,

View File

@@ -463,8 +463,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -764,8 +764,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
# Flaky logits # Flaky logits
_supports_sdpa = False _supports_sdpa = False
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True

View File

@@ -78,8 +78,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std

View File

@@ -64,8 +64,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std

View File

@@ -529,7 +529,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
# Current tests always assume certain inputs to be passed # Current tests always assume certain inputs to be passed
_supports_flash_attn_2 = False _supports_flash_attn_2 = False
_supports_sdpa = False _supports_sdpa = False
# Compile issues
_supports_flex_attn = False _supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -636,7 +636,6 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
# Current tests always assume certain inputs to be passed # Current tests always assume certain inputs to be passed
_supports_flash_attn_2 = False _supports_flash_attn_2 = False
_supports_sdpa = False _supports_sdpa = False
# Compile issues
_supports_flex_attn = False _supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -847,8 +847,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -151,8 +151,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -850,8 +850,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -161,8 +161,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -1096,8 +1096,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
# Compile issues _supports_flex_attn = True
_supports_flex_attn = False
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""

View File

@@ -4570,20 +4570,29 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention" config._attn_implementation = "flex_attention"
# Flex Attention can not use dropout # Flex Attention cannot use dropout
if hasattr(config, "attention_dropout"): if hasattr(config, "attention_dropout"):
config.attention_dropout = 0 config.attention_dropout = 0
if hasattr(config, "attention_probs_dropout_prob"): if hasattr(config, "attention_probs_dropout_prob"):
config.attention_probs_dropout_prob = 0 config.attention_probs_dropout_prob = 0
# Flex attention relies on triton for compilation
# However, triton cannot handle head dimensions smaller than 16
# --> forcing a head dim of at least 16
config.hidden_size *= max(
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
)
if hasattr(config, "head_dim"):
config.head_dim = max(16, config.head_dim)
model = model_class(config).to(device=torch_device) model = model_class(config).to(device=torch_device)
self.assertTrue(model.config._attn_implementation == "flex_attention") self.assertTrue(model.config._attn_implementation == "flex_attention")
# Elaborate workaround for encoder-decoder models as some do not specify their main input # Elaborate workaround for encoder-decoder models as some do not specify their main input
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
if config.is_encoder_decoder: if config.is_encoder_decoder:
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"] dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"] dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
_ = model(**dummy_inputs) _ = model(**dummy_inputs)