[FlexAttention] Reenable flex for encoder-decoder and make the test more robust (#38321)
* reenable most flex attention test cases * style * trigger * trigger
This commit is contained in:
@@ -494,8 +494,7 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -348,8 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -175,8 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -464,8 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -452,8 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -551,8 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -140,8 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -738,8 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -131,8 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -530,8 +530,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
|
||||
_supports_static_cache = False
|
||||
|
||||
@@ -468,8 +468,7 @@ class MarianPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -498,8 +498,7 @@ class MBartPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -436,8 +436,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# compilation errors occur atm
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_factor
|
||||
@@ -1361,8 +1360,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# compilation errors occur atm
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -410,8 +410,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# compilation errors occur atm
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_factor
|
||||
@@ -1292,8 +1291,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# compilation errors occur atm
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -463,8 +463,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -764,8 +764,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
# Flaky logits
|
||||
_supports_sdpa = False
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
|
||||
@@ -78,8 +78,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
||||
@@ -64,8 +64,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
||||
@@ -529,7 +529,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
|
||||
# Current tests always assume certain inputs to be passed
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
||||
@@ -636,7 +636,6 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
|
||||
# Current tests always assume certain inputs to be passed
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
||||
@@ -847,8 +847,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -151,8 +151,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -850,8 +850,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -161,8 +161,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -1096,8 +1096,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
# Compile issues
|
||||
_supports_flex_attn = False
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -4570,20 +4570,29 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flex_attention"
|
||||
# Flex Attention can not use dropout
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
|
||||
# Flex attention relies on triton for compilation
|
||||
# However, triton cannot handle head dimensions smaller than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
config.hidden_size *= max(
|
||||
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||
)
|
||||
if hasattr(config, "head_dim"):
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
|
||||
model = model_class(config).to(device=torch_device)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||
if config.is_encoder_decoder:
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|
||||
|
||||
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
||||
_ = model(**dummy_inputs)
|
||||
|
||||
Reference in New Issue
Block a user