[FlexAttention] Reenable flex for encoder-decoder and make the test more robust (#38321)
* reenable most flex attention test cases * style * trigger * trigger
This commit is contained in:
@@ -494,8 +494,7 @@ class BartPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -348,8 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -175,8 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -464,8 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -452,8 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -551,8 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -140,8 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -738,8 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -131,8 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -530,8 +530,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
|
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
|
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
|
||||||
_supports_static_cache = False
|
_supports_static_cache = False
|
||||||
|
|||||||
@@ -468,8 +468,7 @@ class MarianPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -498,8 +498,7 @@ class MBartPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
|
_no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -436,8 +436,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# compilation errors occur atm
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_factor
|
std = self.config.initializer_factor
|
||||||
@@ -1361,8 +1360,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# compilation errors occur atm
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -410,8 +410,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# compilation errors occur atm
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_factor
|
std = self.config.initializer_factor
|
||||||
@@ -1292,8 +1291,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# compilation errors occur atm
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -463,8 +463,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -764,8 +764,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
# Flaky logits
|
# Flaky logits
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|
||||||
|
|||||||
@@ -78,8 +78,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
|
|||||||
@@ -64,8 +64,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
_no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
|
|||||||
@@ -529,7 +529,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
|
|||||||
# Current tests always assume certain inputs to be passed
|
# Current tests always assume certain inputs to be passed
|
||||||
_supports_flash_attn_2 = False
|
_supports_flash_attn_2 = False
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
# Compile issues
|
|
||||||
_supports_flex_attn = False
|
_supports_flex_attn = False
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -636,7 +636,6 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
|
|||||||
# Current tests always assume certain inputs to be passed
|
# Current tests always assume certain inputs to be passed
|
||||||
_supports_flash_attn_2 = False
|
_supports_flash_attn_2 = False
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
# Compile issues
|
|
||||||
_supports_flex_attn = False
|
_supports_flex_attn = False
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -847,8 +847,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -151,8 +151,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -850,8 +850,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -161,8 +161,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -1096,8 +1096,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
# Compile issues
|
_supports_flex_attn = True
|
||||||
_supports_flex_attn = False
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
|
|||||||
@@ -4570,20 +4570,29 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config._attn_implementation = "flex_attention"
|
config._attn_implementation = "flex_attention"
|
||||||
# Flex Attention can not use dropout
|
# Flex Attention cannot use dropout
|
||||||
if hasattr(config, "attention_dropout"):
|
if hasattr(config, "attention_dropout"):
|
||||||
config.attention_dropout = 0
|
config.attention_dropout = 0
|
||||||
if hasattr(config, "attention_probs_dropout_prob"):
|
if hasattr(config, "attention_probs_dropout_prob"):
|
||||||
config.attention_probs_dropout_prob = 0
|
config.attention_probs_dropout_prob = 0
|
||||||
|
|
||||||
|
# Flex attention relies on Triton for compilation
|
||||||
|
# However, triton cannot handle hidden dimensions of less than 16
|
||||||
|
# --> forcing at least a hidden dim of 16
|
||||||
|
config.hidden_size *= max(
|
||||||
|
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||||
|
)
|
||||||
|
if hasattr(config, "head_dim"):
|
||||||
|
config.head_dim = max(16, config.head_dim)
|
||||||
|
|
||||||
model = model_class(config).to(device=torch_device)
|
model = model_class(config).to(device=torch_device)
|
||||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||||
|
|
||||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
|
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
|
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|
||||||
|
|
||||||
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
||||||
_ = model(**dummy_inputs)
|
_ = model(**dummy_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user