From e97d7a5be545d123862d09f6c32e8d90b98bbf0c Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 9 Jan 2025 20:03:33 +0100 Subject: [PATCH] add `_supports_flex_attn = True` for models that do support it (#35598) * add `_supports_flex_attn = True` * fix repo consistency --- src/transformers/models/aria/modeling_aria.py | 1 + src/transformers/models/cohere/modeling_cohere.py | 1 + src/transformers/models/cohere2/modeling_cohere2.py | 1 + src/transformers/models/gemma/modeling_gemma.py | 1 + src/transformers/models/gemma2/modeling_gemma2.py | 1 + src/transformers/models/glm/modeling_glm.py | 1 + src/transformers/models/granite/modeling_granite.py | 1 + src/transformers/models/llama/modeling_llama.py | 1 + src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + src/transformers/models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/olmoe/modeling_olmoe.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + src/transformers/models/phimoe/modeling_phimoe.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + 18 files changed, 18 insertions(+) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 739aa0af8d..12d7224b21 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -704,6 +704,7 @@ class AriaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 356d330007..ecb944a924 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -413,6 +413,7 @@ class CoherePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 8fa99c06a0..5e36b8a36c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -413,6 +413,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 810fcd63e0..d9f08cc1ae 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -380,6 +380,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8b995d1a08..fb60a611a4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -410,6 +410,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index c6ea3a1d5f..aac6e471bd 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -395,6 +395,7 @@ class GlmPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index ef73b8015f..3c5c52524f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -395,6 +395,7 @@ class GranitePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 00568a9737..16d4ba4dac 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -384,6 +384,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ec31cc41d2..a44f96dd7b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -356,6 +356,7 @@ class MistralPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1183fd4dbd..c7130a3a5f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -478,6 +478,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b6742702c5..6cb2baa4e0 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -360,6 +360,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 1a256cf098..7219284bd8 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -361,6 +361,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index b8bc6f5de9..d4befc96d1 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -762,6 +762,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8f1867e4f4..63d40070c2 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -356,6 +356,7 @@ class PhiPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 66b01c6a0a..dd6d0d1dc3 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -422,6 +422,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index b8b05d9381..b540dd1830 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -910,6 +910,7 @@ class PhimoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 03f1748d01..e8d1bff8a2 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -369,6 +369,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 605f63b301..3668c076d2 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -360,6 +360,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True