From eb1a007f7f0bcff45b0b6d43759c583246946f91 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 23 Jul 2025 11:35:18 +0200 Subject: [PATCH] Rename `supports_static_cache` to `can_compile_fullgraph` (#39505) * update all * Apply suggestions from code review Co-authored-by: Joao Gante * apply suggestions * fix copies --------- Co-authored-by: Joao Gante --- .../modular-transformers/modeling_my_new_model2.py | 2 +- .../modular-transformers/modeling_new_task_model.py | 2 +- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/generation/utils.py | 5 +++-- src/transformers/modeling_utils.py | 3 +-- src/transformers/models/arcee/modeling_arcee.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/aria/modular_aria.py | 2 +- .../models/aya_vision/modeling_aya_vision.py | 2 +- .../models/aya_vision/modular_aya_vision.py | 2 +- src/transformers/models/bart/modeling_bart.py | 2 +- .../bigbird_pegasus/modeling_bigbird_pegasus.py | 2 +- src/transformers/models/biogpt/modeling_biogpt.py | 2 +- src/transformers/models/biogpt/modular_biogpt.py | 2 +- src/transformers/models/bitnet/modeling_bitnet.py | 2 +- .../models/blenderbot/modeling_blenderbot.py | 2 +- .../blenderbot_small/modeling_blenderbot_small.py | 2 +- src/transformers/models/blip_2/modeling_blip_2.py | 2 +- src/transformers/models/bloom/modeling_bloom.py | 2 +- .../models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/cohere2/modeling_cohere2.py | 2 +- src/transformers/models/csm/modeling_csm.py | 2 +- src/transformers/models/csm/modular_csm.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- .../modeling_decision_transformer.py | 2 +- .../models/deepseek_v2/modeling_deepseek_v2.py | 2 +- .../models/deepseek_v3/modeling_deepseek_v3.py | 2 +- src/transformers/models/dia/modeling_dia.py | 2 +- src/transformers/models/dia/modular_dia.py | 2 +- .../models/diffllama/modeling_diffllama.py | 2 +- src/transformers/models/doge/modeling_doge.py | 2 +- src/transformers/models/doge/modular_doge.py | 2 +- src/transformers/models/dots1/modeling_dots1.py | 2 +- src/transformers/models/emu3/modeling_emu3.py | 4 +--- src/transformers/models/emu3/modular_emu3.py | 2 -- .../models/ernie4_5/modeling_ernie4_5.py | 2 +- .../models/ernie4_5_moe/modeling_ernie4_5_moe.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma3/modeling_gemma3.py | 2 +- src/transformers/models/gemma3n/modeling_gemma3n.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/glm4/modeling_glm4.py | 2 +- .../models/glm4_moe/modeling_glm4_moe.py | 2 +- src/transformers/models/glm4_moe/modular_glm4_moe.py | 2 +- src/transformers/models/glm4v/modeling_glm4v.py | 2 +- .../models/got_ocr2/modeling_got_ocr2.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- .../models/gpt_neox/modeling_gpt_neox.py | 2 +- .../gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/granite/modeling_granite.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../granitemoehybrid/modeling_granitemoehybrid.py | 2 +- .../granitemoeshared/modeling_granitemoeshared.py | 2 +- src/transformers/models/helium/modeling_helium.py | 2 +- src/transformers/models/idefics/modeling_idefics.py | 2 +- .../models/instructblip/modeling_instructblip.py | 4 ++-- .../instructblipvideo/modeling_instructblipvideo.py | 4 ++-- .../models/internvl/modeling_internvl.py | 2 +- src/transformers/models/janus/modeling_janus.py | 4 ++-- src/transformers/models/janus/modular_janus.py | 4 ++-- src/transformers/models/lfm2/modeling_lfm2.py | 2 +- src/transformers/models/lfm2/modular_lfm2.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/llama4/modeling_llama4.py | 2 +- src/transformers/models/llava/modeling_llava.py | 2 +- .../models/llava_next/modeling_llava_next.py | 2 +- .../llava_next_video/modeling_llava_next_video.py | 2 +- .../llava_onevision/modeling_llava_onevision.py | 2 +- src/transformers/models/longt5/modeling_longt5.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 2 +- src/transformers/models/marian/modeling_marian.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- src/transformers/models/minimax/modeling_minimax.py | 3 +-- src/transformers/models/minimax/modular_minimax.py | 3 +-- src/transformers/models/mistral/modeling_mistral.py | 2 +- .../models/mistral3/modeling_mistral3.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/mixtral/modular_mixtral.py | 2 +- src/transformers/models/mllama/modeling_mllama.py | 4 ++-- .../modeling_modernbert_decoder.py | 2 +- .../modernbert_decoder/modular_modernbert_decoder.py | 2 +- .../models/moonshine/modeling_moonshine.py | 2 +- .../models/moonshine/modular_moonshine.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- .../models/paligemma/modeling_paligemma.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 2 +- .../models/pegasus_x/modeling_pegasus_x.py | 2 +- .../models/perception_lm/modeling_perception_lm.py | 2 +- .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../phi4_multimodal/modeling_phi4_multimodal.py | 2 +- src/transformers/models/phimoe/modeling_phimoe.py | 2 +- .../models/pix2struct/modeling_pix2struct.py | 2 +- .../models/pop2piano/modeling_pop2piano.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 2 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 2 +- src/transformers/models/qwen3/modeling_qwen3.py | 2 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 2 +- src/transformers/models/smollm3/modeling_smollm3.py | 2 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- .../modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/t5gemma/modeling_t5gemma.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- .../models/video_llava/modeling_video_llava.py | 2 +- .../models/vipllava/modeling_vipllava.py | 2 +- src/transformers/models/voxtral/modeling_voxtral.py | 2 +- src/transformers/models/voxtral/modular_voxtral.py | 2 +- src/transformers/models/whisper/modeling_whisper.py | 2 +- src/transformers/utils/auto_docstring.py | 5 +++-- tests/generation/test_utils.py | 12 ++++++------ tests/test_modeling_common.py | 2 +- 130 files changed, 143 insertions(+), 148 deletions(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 19b059699e..e56eeec7d7 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -294,7 +294,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": MyNewModel2DecoderLayer, diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 9111883cfe..2a3df8e9c1 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -94,7 +94,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_quantized_cache = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index fc90cce75a..ee90750cac 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -293,7 +293,7 @@ class SuperPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": SuperDecoderLayer, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index fe38e0d2b1..6f4adcfeb1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2059,7 +2059,7 @@ class GenerationMixin(ContinuousMixin): ) if generation_config.cache_implementation is not None: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static" and not self._supports_static_cache: + if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph: raise ValueError( "This model does not support `cache_implementation='static'`. Please check the following " "issue: https://github.com/huggingface/transformers/issues/28981" @@ -2215,7 +2215,8 @@ class GenerationMixin(ContinuousMixin): using_compilable_cache = ( isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable ) - can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache + # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile) + can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph # Exception 1: Some quantization methods do not support compilation if getattr(self, "hf_quantizer", None) is not None: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 71eb25f74a..dc4e997fef 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2063,8 +2063,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH # Flex Attention support _supports_flex_attn = False - # Has support `torch.compile(fullgraph=True)` - _supports_static_cache = False + _can_compile_fullgraph = False # A tensor parallel plan to be applied to the model when TP is enabled. For # top-level models, this attribute is currently defined in respective model diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index de6a102771..e288f63d71 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -313,7 +313,7 @@ class ArceePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": ArceeDecoderLayer, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index f3d6df532a..9144cc6bdd 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -654,7 +654,7 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True _can_record_outputs = { "hidden_states": AriaTextDecoderLayer, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 1fb64f50b7..d980898460 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1302,7 +1302,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): class AriaPreTrainedModel(LlamaPreTrainedModel): config: AriaConfig base_model_prefix = "" - _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 70c85ab27f..df45633cc7 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -96,7 +96,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 9dcf02547a..8e77762917 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -90,7 +90,7 @@ class AyaVisionMultiModalProjector(nn.Module): class AyaVisionPreTrainedModel(LlavaPreTrainedModel): - _supports_static_cache = False + _can_compile_fullgraph = False class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 02877c3d89..236a2f6471 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -493,7 +493,7 @@ class BartPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ac3636a048..4eeecb5577 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1565,7 +1565,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_param_buffer_assignment = False - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 8bf976caf5..d6ed401cd6 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -347,7 +347,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 43717ae447..db5ad5dbbc 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -172,7 +172,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 59bd7b2ef9..7bbfab8cdd 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -312,7 +312,7 @@ class BitNetPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": BitNetDecoderLayer, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 445ef06b0b..b4ec543b3e 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -458,7 +458,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index dc7eac4390..b248b2f0da 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -451,7 +451,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 4c7a52e6fb..b19ae2f8dc 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1831,7 +1831,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config: Blip2Config main_input_name = "pixel_values" - _supports_static_cache = True + _can_compile_fullgraph = True _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f999872bef..cc8cd4eae9 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -434,7 +434,7 @@ class BloomPreTrainedModel(PreTrainedModel): _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 1e43bbc6ec..b70387d595 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -815,7 +815,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index b1378f5517..3dbb6f5ecc 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -287,7 +287,7 @@ class CodeGenPreTrainedModel(PreTrainedModel): _no_split_modules = ["CodeGenBlock"] _skip_keys_device_placement = "past_key_values" - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 15934195ce..fc4314386b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -345,7 +345,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": CohereDecoderLayer, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 85faa46d17..88c3afe607 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -322,7 +322,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Cohere2DecoderLayer, diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 3cc477b6dc..f36b81f886 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -371,7 +371,7 @@ class CsmPreTrainedModel(PreTrainedModel): # does not because of Mimi codec model # _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": CsmDecoderLayer, diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 60b2612759..ad11589283 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -134,7 +134,7 @@ class CsmPreTrainedModel(PreTrainedModel): # does not because of Mimi codec model # _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": CsmDecoderLayer, diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 86b4944f08..ee5ec65f86 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -810,7 +810,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module: nn.Module): std = self.config.initializer_range diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 1b33296d7d..6a492e937a 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -453,7 +453,7 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): is_parallelizable = True supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index e2cbf27af9..bffef42464 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -456,7 +456,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": DeepseekV2DecoderLayer, diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 98145e74eb..75cf107756 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -498,7 +498,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": DeepseekV3DecoderLayer, diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 927e808e9b..13e375e15e 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -67,7 +67,7 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 1c67a491ca..1bf4cbd5ef 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -62,7 +62,7 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 4660bfded7..5deec876f6 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -533,7 +533,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = False _can_record_outputs = { "hidden_states": DiffLlamaDecoderLayer, diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 58a18655b2..60e640c5f1 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -486,7 +486,7 @@ class DogePreTrainedModel(PreTrainedModel): _supports_flash_attn = False _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "router_logits": OutputRecorder(DogeCDMoE, index=1), diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 49094ec499..f9b8154ab1 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -564,7 +564,7 @@ class DogeDecoderLayer(GradientCheckpointingLayer): class DogePreTrainedModel(LlamaPreTrainedModel): _supports_flash_attn = False - _supports_static_cache = False + _can_compile_fullgraph = False _can_record_outputs = { "router_logits": OutputRecorder(DogeCDMoE, index=1), "hidden_states": DogeDecoderLayer, diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 4a90048d7e..26fdc9f76c 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -418,7 +418,7 @@ class Dots1PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Dots1DecoderLayer, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 570c32c2fb..182afe6b90 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1098,7 +1098,7 @@ class Emu3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False _supports_flex_attn = True _supports_attention_backend = True @@ -1307,7 +1307,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): class Emu3Model(Emu3PreTrainedModel): _checkpoint_conversion_mapping = {"text_model.model": "text_model"} - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) @@ -1450,7 +1449,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): "^vqmodel": "model.vqmodel", "^text_model.lm_head": "lm_head", } - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 9aa66fe65e..7bd59d1fae 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -889,7 +889,6 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): class Emu3Model(Emu3PreTrainedModel): _checkpoint_conversion_mapping = {"text_model.model": "text_model"} - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) @@ -1032,7 +1031,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): "^vqmodel": "model.vqmodel", "^text_model.lm_head": "lm_head", } - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 0a9a170184..de575baf79 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -311,7 +311,7 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Ernie4_5DecoderLayer, diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 9081f66265..74671bb33f 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -473,7 +473,7 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { "router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1), diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 4033e2c14d..5cd2bd5058 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -643,7 +643,7 @@ class FalconPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e4b2e5429c..4b1fefd2cd 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -310,7 +310,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": GemmaDecoderLayer, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 459ae67fd4..2ce51042ed 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -343,7 +343,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Gemma2DecoderLayer, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 42ec9410b9..307b671844 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -434,7 +434,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Gemma3DecoderLayer, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index edd6717291..3c304bbcf6 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1490,7 +1490,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Gemma3nTextDecoderLayer, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 3b8b2d86c9..6dd31884e2 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -327,7 +327,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": GlmDecoderLayer, diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 8888055c37..55b3ecfd7a 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -331,7 +331,7 @@ class Glm4PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Glm4DecoderLayer, diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 93706258dc..19f0a6b2e5 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -403,7 +403,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": Glm4MoeDecoderLayer, diff --git a/src/transformers/models/glm4_moe/modular_glm4_moe.py b/src/transformers/models/glm4_moe/modular_glm4_moe.py index 509fc39d39..1f6628d938 100644 --- a/src/transformers/models/glm4_moe/modular_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modular_glm4_moe.py @@ -310,7 +310,7 @@ class Glm4MoeDecoderLayer(DeepseekV3DecoderLayer): class Glm4MoePreTrainedModel(DeepseekV3PreTrainedModel): - _supports_static_cache = False + _can_compile_fullgraph = False class Glm4MoeModel(DeepseekV3Model): diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 57d99a1f98..25ccd9f10f 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -407,7 +407,7 @@ class Glm4vPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index dead80a503..464d54f819 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -283,7 +283,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_flash_attn = False _supports_sdpa = False - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = False _supports_attention_backend = True diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index c853d80e4a..80442af911 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -563,7 +563,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_attention_backend = True - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7d655bd0e6..89e6f7182a 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -477,7 +477,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True - _supports_static_cache = False # TODO: needs a HybridCache + _can_compile_fullgraph = False # TODO: needs a HybridCache def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index ec9efd340c..216ba439ca 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -364,7 +364,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": GPTNeoXDecoderLayer, diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 9e1859e794..e80f788023 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -48,7 +48,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 9f8c08f96f..0622bf5ed0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -472,7 +472,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 58a9402bb6..5d76b63b81 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -309,7 +309,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": GraniteDecoderLayer, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 8a51c28053..5ea293e5bf 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -592,7 +592,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): super()._init_weights(module) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 598673586c..ab31709f3d 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1212,7 +1212,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 9afb268a49..b10369e767 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -510,7 +510,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): super()._init_weights(module) diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 91c35ccf6c..3aac2621e8 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -312,7 +312,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": HeliumDecoderLayer, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 741886c142..ac8b7776c5 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -880,7 +880,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True - _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs + _can_compile_fullgraph = False # IDEFICS cannot compile due to dynamic control flow when checking inputs _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index b88c003660..bcafeeec1e 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -340,7 +340,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _no_split_modules = [ "InstructBlipQFormerEmbeddings", @@ -1354,7 +1354,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati config: InstructBlipConfig main_input_name = "pixel_values" - _supports_static_cache = True + _can_compile_fullgraph = True _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipConfig): diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index cec9198253..8e098183e2 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -827,7 +827,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _no_split_modules = [ "InstructBlipVideoQFormerEmbeddings", @@ -1360,7 +1360,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel config: InstructBlipVideoConfig main_input_name = "pixel_values" - _supports_static_cache = True + _can_compile_fullgraph = True _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipVideoConfig): diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 55b9b6dcfe..8e1c616700 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -512,7 +512,7 @@ class InternVLPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 0553326af1..b93e1a8b67 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -63,7 +63,7 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False @@ -1105,7 +1105,7 @@ class JanusModel(JanusPreTrainedModel): class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, config: JanusConfig): super().__init__(config) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 69b9293374..29accd88e5 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -391,7 +391,7 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_param_buffer_assignment = False @@ -965,7 +965,7 @@ class JanusModel(JanusPreTrainedModel): class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] - _supports_static_cache = True + _can_compile_fullgraph = True def __init__(self, config: JanusConfig): super().__init__(config) diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 548a62aa85..5a60fed7eb 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -578,7 +578,7 @@ class Lfm2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": Lfm2DecoderLayer, diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 9ffeb95513..c3c39e4677 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -437,7 +437,7 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): class Lfm2PreTrainedModel(LlamaPreTrainedModel): - _supports_static_cache = False + _can_compile_fullgraph = False class Lfm2Model(LlamaModel): diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 758f0ab61a..11bcb93f86 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -317,7 +317,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": LlamaDecoderLayer, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 85aeb70ce3..53d9367b7c 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -437,7 +437,7 @@ class Llama4PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 56f056023b..92199c9505 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -121,7 +121,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 7cbad1b980..94f03925b8 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -232,7 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 7721d760ea..dce37a4dd9 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -173,7 +173,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index ea5ca1e5ea..41c39d26f3 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -286,7 +286,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 71e5178ef5..87badf1ad9 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1250,7 +1250,7 @@ class LongT5PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] - _supports_static_cache = False # TODO: @raushan more involved due to local/global attn + _can_compile_fullgraph = False # TODO: @raushan more involved due to local/global attn @property # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 59641c6084..ccd7c000ac 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -525,7 +525,7 @@ class M2M100PreTrainedModel(PreTrainedModel): _supports_flex_attn = True # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model - _supports_static_cache = False + _can_compile_fullgraph = False def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 8a3e7965f5..80ad444075 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -467,7 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): std = self.config.init_std diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5aafbc1de6..2f6b5c20ef 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -492,7 +492,7 @@ class MBartPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 0be377794d..260ea6f7ce 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1376,7 +1376,7 @@ class MimiPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 4e6ce12c22..1469b6d242 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -581,8 +581,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - # Note: only supports MiniMaxCache - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1), diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 6844a9d0fc..99be8f7fb5 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -468,8 +468,7 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer): class MiniMaxPreTrainedModel(MixtralPreTrainedModel): - # Note: only supports MiniMaxCache - _supports_static_cache = False + _can_compile_fullgraph = False _can_record_outputs = { "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1), "hidden_states": MiniMaxDecoderLayer, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4981cdce5b..caf3681f14 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -256,7 +256,7 @@ class MistralPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": MistralDecoderLayer, diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index cf46657627..bc61bc55b1 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -186,7 +186,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1b364e0c69..043862a3a2 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -383,7 +383,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index de02a2a833..c4a7b5b2df 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -277,7 +277,7 @@ class MixtralRotaryEmbedding(MistralRotaryEmbedding): class MixtralPreTrainedModel(MistralPreTrainedModel): - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _can_record_outputs = { "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), "hidden_states": MixtralDecoderLayer, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index fbc8b287b6..266a916cef 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -850,7 +850,7 @@ class MllamaPreTrainedModel(PreTrainedModel): "MllamaSelfAttentionDecoderLayer", ] - _supports_static_cache = False # static cache cannot have different shapes for each layer + _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True @@ -1449,7 +1449,7 @@ class MllamaTextModel(MllamaPreTrainedModel): ) class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): config: MllamaTextConfig - _supports_static_cache = True # only the LLM without cross attn can do compile + _can_compile_fullgraph = True # only the LLM without cross attn can do compile base_model_prefix = "language_model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index c3d90771e1..0b1d572a1f 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -224,7 +224,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): _supports_flash_attn = True _supports_sdpa = False _supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": ModernBertDecoderLayer, diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index d215ccbf0b..3b6b936e15 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -401,7 +401,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): _supports_flash_attn = True _supports_sdpa = False _supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": ModernBertDecoderLayer, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index c17fc883c1..15598ccb32 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -462,7 +462,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True # TODO arthur, how do we separate when it cross / self coming from different layer? def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 6278fab115..326cd743ce 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -497,7 +497,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True # TODO arthur, how do we separate when it cross / self coming from different layer? def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0d2ed6b402..071010abf1 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -757,7 +757,7 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_static_cache = True + _can_compile_fullgraph = True _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index c0885df873..714b71f6bb 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -589,7 +589,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 52b97bc64e..987611bca2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -294,7 +294,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": OlmoDecoderLayer, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 3ba2114dc4..d113a26462 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -299,7 +299,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Olmo2DecoderLayer, diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 8caff4d9ab..38b538fdf3 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -706,7 +706,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2fe9a14677..ff5e8dfa01 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -313,7 +313,7 @@ class OPTPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 581269653c..2d82dccc18 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -114,7 +114,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index d37ee0007c..33fc066d89 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -458,7 +458,7 @@ class PegasusPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e2e04804c0..029ad0a2e4 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -758,7 +758,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 37e1bed147..65f35c7951 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -95,7 +95,7 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 92affb08eb..2779b2a504 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -390,7 +390,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_static_cache = True + _can_compile_fullgraph = True _supports_sdpa = True _supports_flash_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 5ee828dbf5..4cf3d54a65 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -298,7 +298,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": PhiDecoderLayer, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 67852e8fc9..399e207a42 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -287,7 +287,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Phi3DecoderLayer, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index a984a7957c..c0376fc0e2 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1582,7 +1582,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Phi4MultimodalDecoderLayer, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 886c9a7d84..3b1369a924 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -887,7 +887,7 @@ class PhimoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 820d59bad3..8ef1ea8090 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -351,7 +351,7 @@ class Pix2StructVisionEncoder(nn.Module): class Pix2StructPreTrainedModel(PreTrainedModel): config: Pix2StructConfig - _supports_static_cache = False + _can_compile_fullgraph = False @property def dummy_inputs(self): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index bdf712a2f9..71a1397cc3 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -577,7 +577,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): is_parallelizable = False supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index bbe8fe54cd..5aaaae52fc 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -259,7 +259,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Qwen2DecoderLayer, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index b8da2d7846..f690ca5108 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -67,7 +67,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = False + _can_compile_fullgraph = False _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index f899d0f8c4..b8e5bc6216 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1133,7 +1133,7 @@ class Qwen2_5OmniConfig(PretrainedConfig): class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): config: Qwen2_5OmniConfig - _supports_static_cache = False + _can_compile_fullgraph = False class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d653b9de87..c270e2714c 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -326,7 +326,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 6c17e0c9c4..068199e6d9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -660,7 +660,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index aafaebf7cf..695df11f37 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -285,7 +285,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Qwen3DecoderLayer, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 230ac9a68c..e7080f9d1b 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -406,7 +406,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { "router_logits": OutputRecorder(Qwen3MoeSparseMoeBlock, index=1), diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 2af412933b..082cc792e0 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -289,7 +289,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": SmolLM3DecoderLayer, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index c728d9cdbc..b70eaf1a09 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -621,7 +621,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6f1114b31a..cce7b5cebb 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -293,7 +293,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Starcoder2DecoderLayer, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index f22484047b..d6241ff134 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -766,7 +766,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False _no_split_modules = ["SwitchTransformersBlock"] @property diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e1570a731d..e39c4b2f99 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -771,7 +771,7 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_static_cache = True + _can_compile_fullgraph = True _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 9f1ad6928b..e2b563cbac 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -585,7 +585,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": T5GemmaDecoderLayer, diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 817ea913ae..a84d15051b 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -255,7 +255,7 @@ class UdopPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_static_cache = False + _can_compile_fullgraph = False _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index cf2e999449..e171be26fe 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -508,7 +508,7 @@ class UMT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_static_cache = True + _can_compile_fullgraph = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index aea7256868..befa350b90 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -135,7 +135,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 2f3f536f79..df3635b690 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -122,7 +122,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index b2350310a8..ae949a4579 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -236,7 +236,7 @@ class VoxtralPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_cache_class = True _supports_attention_backend = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): # important: this ported version of Voxtral isn't meant for training from scratch - only diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index fdb9862ad5..a3cb8c3ed0 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -47,7 +47,7 @@ class VoxtralPreTrainedModel(Qwen2AudioPreTrainedModel): _supports_flex_attn = True _supports_cache_class = True _supports_attention_backend = True - _supports_static_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 917113cbb0..ad8ac7cee3 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -553,7 +553,7 @@ class WhisperPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - _supports_static_cache = True + _can_compile_fullgraph = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index f277df1af1..c1e2f8e1eb 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -965,8 +965,9 @@ class ClassAttrs: _supports_flex_attn = r""" Whether the model's attention implementation supports FlexAttention. """ - _supports_static_cache = r""" - Whether the model supports a `StaticCache` instance as `past_key_values`. + _can_compile_fullgraph = r""" + Whether the model can `torch.compile` fullgraph without graph breaks. Models will auto-compile if this flag is set to `True` + in inference, if a compilable cache is used. """ _supports_attention_backend = r""" Whether the model supports attention interface functions. This flag signal that the model can be used as an efficient backend in TGI and vLLM. diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 26a5a7bf3d..5f5f33de35 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1764,7 +1764,7 @@ class GenerationTesterMixin: to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache. """ for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: + if not model_class._can_compile_fullgraph: self.skipTest(reason="This model does not support the static cache format") config, inputs_dict = self.prepare_config_and_inputs_for_generate() @@ -1984,7 +1984,7 @@ class GenerationTesterMixin: """ set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: + if not model_class._can_compile_fullgraph: self.skipTest(reason="This model does not support the static cache format") config, inputs_dict = self.prepare_config_and_inputs_for_generate() @@ -2087,8 +2087,8 @@ class GenerationTesterMixin: set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: # 1. Test exclusion criteria - if not model_class._supports_static_cache: - self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") + if not model_class._can_compile_fullgraph: + self.skipTest("This model doesn't support compilation without graph breaks") # 2. Prepares two sets of inputs config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4) @@ -2201,8 +2201,8 @@ class GenerationTesterMixin: In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. """ for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") + if not model_class._can_compile_fullgraph: + self.skipTest("This model doesn't support compilation without graph breaks") config, inputs_dict = self.prepare_config_and_inputs_for_generate() if self.has_attentions: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9c4c0da4ee..f490108817 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4415,7 +4415,7 @@ class ModelTesterMixin: set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: + if not model_class._can_compile_fullgraph: self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") config, _ = self.model_tester.prepare_config_and_inputs_for_common() set_config_for_less_flaky_test(config)