From 8d6259b0b8290c2406949ce6342051b1f09a074c Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 15 Jul 2025 12:34:06 +0500 Subject: [PATCH] [refactor] set attention implementation (#38974) * update * fix some tests * init from config, changes it in-place, add deepcopy in tests * fix modernbert * don't delete thsi config attr * update * style and copies * skip tests in generation * fix style * accidentally removed flash-attn-3, revert * docs * forgot about flags set to False * fix copies * address a few comments * fix copies * custom code BC --- src/transformers/modeling_utils.py | 405 +++++++----------- .../models/aimv2/modeling_aimv2.py | 3 +- .../models/aimv2/modular_aimv2.py | 3 +- .../models/arcee/modeling_arcee.py | 3 +- src/transformers/models/aria/modeling_aria.py | 5 +- src/transformers/models/aria/modular_aria.py | 2 +- .../modeling_audio_spectrogram_transformer.py | 3 +- .../models/aya_vision/modeling_aya_vision.py | 3 +- .../models/bamba/modeling_bamba.py | 3 +- .../models/bamba/modular_bamba.py | 3 +- src/transformers/models/bark/modeling_bark.py | 39 +- src/transformers/models/bart/modeling_bart.py | 3 +- .../models/biogpt/modeling_biogpt.py | 3 +- .../models/biogpt/modular_biogpt.py | 3 +- .../models/bitnet/modeling_bitnet.py | 3 +- .../models/blenderbot/modeling_blenderbot.py | 3 +- .../modeling_blenderbot_small.py | 3 +- .../models/blip_2/modeling_blip_2.py | 15 +- .../models/chameleon/modeling_chameleon.py | 3 +- src/transformers/models/clip/modeling_clip.py | 3 +- .../models/cohere/modeling_cohere.py | 3 +- .../models/cohere2/modeling_cohere2.py | 3 +- .../models/colqwen2/modeling_colqwen2.py | 3 +- .../models/colqwen2/modular_colqwen2.py | 3 +- src/transformers/models/csm/modeling_csm.py | 3 +- src/transformers/models/csm/modular_csm.py | 3 +- .../data2vec/modeling_data2vec_audio.py | 3 +- .../models/data2vec/modular_data2vec_audio.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 3 +- .../deepseek_v2/modeling_deepseek_v2.py | 3 +- .../deepseek_v3/modeling_deepseek_v3.py | 3 +- src/transformers/models/deit/modeling_deit.py | 3 +- src/transformers/models/dia/modeling_dia.py | 3 +- src/transformers/models/dia/modular_dia.py | 3 +- .../models/diffllama/modeling_diffllama.py | 3 +- .../models/dinov2/modeling_dinov2.py | 3 +- .../modeling_dinov2_with_registers.py | 3 +- .../models/distilbert/modeling_distilbert.py | 3 +- src/transformers/models/doge/modeling_doge.py | 3 +- src/transformers/models/doge/modular_doge.py | 3 +- .../models/dots1/modeling_dots1.py | 3 +- src/transformers/models/dpt/modeling_dpt.py | 3 +- src/transformers/models/emu3/modeling_emu3.py | 6 +- src/transformers/models/emu3/modular_emu3.py | 3 +- .../modeling_encoder_decoder.py | 3 +- src/transformers/models/eomt/modeling_eomt.py | 3 +- src/transformers/models/eomt/modular_eomt.py | 3 +- src/transformers/models/esm/modeling_esm.py | 3 +- .../models/esm/modeling_esmfold.py | 2 +- .../models/falcon/modeling_falcon.py | 3 +- .../models/falcon_h1/modeling_falcon_h1.py | 3 +- .../models/falcon_h1/modular_falcon_h1.py | 3 +- src/transformers/models/fuyu/modeling_fuyu.py | 3 +- .../models/gemma/modeling_gemma.py | 3 +- .../models/gemma2/modeling_gemma2.py | 3 +- .../models/gemma3/modeling_gemma3.py | 3 +- .../models/gemma3n/modeling_gemma3n.py | 3 +- src/transformers/models/glm/modeling_glm.py | 3 +- src/transformers/models/glm4/modeling_glm4.py | 3 +- .../models/glm4v/modeling_glm4v.py | 3 +- .../models/got_ocr2/modeling_got_ocr2.py | 3 +- src/transformers/models/gpt2/modeling_gpt2.py | 3 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +- .../models/gpt_neo/modeling_gpt_neo.py | 3 +- .../models/gpt_neox/modeling_gpt_neox.py | 3 +- src/transformers/models/gptj/modeling_gptj.py | 3 +- .../models/granite/modeling_granite.py | 3 +- .../granite_speech/modeling_granite_speech.py | 3 +- .../models/granitemoe/modeling_granitemoe.py | 3 +- .../modeling_granitemoehybrid.py | 3 +- .../modeling_granitemoeshared.py | 3 +- .../models/helium/modeling_helium.py | 3 +- .../models/hubert/modeling_hubert.py | 3 +- .../models/hubert/modular_hubert.py | 3 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/idefics2/modeling_idefics2.py | 6 +- .../models/idefics3/modeling_idefics3.py | 6 +- .../models/ijepa/modeling_ijepa.py | 3 +- .../models/ijepa/modular_ijepa.py | 3 +- .../instructblip/modeling_instructblip.py | 5 +- .../modeling_instructblipvideo.py | 5 +- .../models/internvl/modeling_internvl.py | 6 +- .../models/internvl/modular_internvl.py | 3 +- .../models/jamba/modeling_jamba.py | 3 +- .../models/janus/modeling_janus.py | 3 +- .../models/janus/modular_janus.py | 3 +- .../models/jetmoe/modeling_jetmoe.py | 3 +- .../models/kosmos2/modeling_kosmos2.py | 3 +- .../modeling_kyutai_speech_to_text.py | 3 +- src/transformers/models/lfm2/modeling_lfm2.py | 3 +- .../models/lightglue/modeling_lightglue.py | 3 +- .../models/lightglue/modular_lightglue.py | 3 +- .../models/llama/modeling_llama.py | 3 +- .../models/llama4/modeling_llama4.py | 2 +- .../models/llava/modeling_llava.py | 3 +- .../models/llava_next/modeling_llava_next.py | 3 +- .../modeling_llava_next_video.py | 3 +- .../modeling_llava_onevision.py | 3 +- .../models/m2m_100/modeling_m2m_100.py | 3 +- .../models/marian/modeling_marian.py | 3 +- .../models/mbart/modeling_mbart.py | 3 +- src/transformers/models/mimi/modeling_mimi.py | 3 +- .../models/minimax/modeling_minimax.py | 3 +- .../models/mistral/modeling_mistral.py | 3 +- .../models/mistral3/modeling_mistral3.py | 3 +- .../models/mixtral/modeling_mixtral.py | 3 +- src/transformers/models/mlcd/modeling_mlcd.py | 3 +- src/transformers/models/mlcd/modular_mlcd.py | 3 +- .../models/mllama/modeling_mllama.py | 3 +- .../models/modernbert/modeling_modernbert.py | 44 +- .../models/modernbert/modular_modernbert.py | 44 +- .../models/moonshine/modeling_moonshine.py | 3 +- .../models/moonshine/modular_moonshine.py | 3 +- .../models/moshi/modeling_moshi.py | 6 +- .../models/musicgen/modeling_musicgen.py | 6 +- .../modeling_musicgen_melody.py | 6 +- .../models/nemotron/modeling_nemotron.py | 3 +- .../models/nllb_moe/modeling_nllb_moe.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 3 +- .../models/olmo2/modeling_olmo2.py | 3 +- .../models/olmoe/modeling_olmoe.py | 3 +- src/transformers/models/opt/modeling_opt.py | 3 +- .../models/paligemma/modeling_paligemma.py | 3 +- .../models/pegasus/modeling_pegasus.py | 3 +- .../models/pegasus_x/modeling_pegasus_x.py | 3 +- .../perception_lm/modeling_perception_lm.py | 3 +- .../models/persimmon/modeling_persimmon.py | 3 +- src/transformers/models/phi/modeling_phi.py | 3 +- src/transformers/models/phi3/modeling_phi3.py | 3 +- .../modeling_phi4_multimodal.py | 9 +- .../modular_phi4_multimodal.py | 6 +- .../models/phimoe/modeling_phimoe.py | 3 +- .../models/pixtral/modeling_pixtral.py | 6 +- .../models/plbart/modeling_plbart.py | 3 +- .../models/plbart/modular_plbart.py | 3 +- .../models/qwen2/modeling_qwen2.py | 3 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 3 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 3 +- .../qwen2_audio/modeling_qwen2_audio.py | 3 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 3 +- .../models/qwen3/modeling_qwen3.py | 3 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 3 +- src/transformers/models/rag/modeling_rag.py | 3 +- .../modeling_recurrent_gemma.py | 2 +- src/transformers/models/sew/modeling_sew.py | 3 +- src/transformers/models/sew/modular_sew.py | 3 +- .../models/siglip/modeling_siglip.py | 3 +- .../models/siglip2/modeling_siglip2.py | 3 +- .../models/smollm3/modeling_smollm3.py | 3 +- .../models/smolvlm/modeling_smolvlm.py | 6 +- .../modeling_speech_encoder_decoder.py | 3 +- .../speech_to_text/modeling_speech_to_text.py | 2 +- .../models/stablelm/modeling_stablelm.py | 3 +- .../models/starcoder2/modeling_starcoder2.py | 3 +- .../models/t5gemma/modeling_t5gemma.py | 3 +- .../modeling_time_series_transformer.py | 2 +- .../models/unispeech/modeling_unispeech.py | 3 +- .../models/unispeech/modular_unispeech.py | 3 +- .../unispeech_sat/modeling_unispeech_sat.py | 3 +- .../unispeech_sat/modular_unispeech_sat.py | 3 +- .../video_llava/modeling_video_llava.py | 3 +- .../models/videomae/modeling_videomae.py | 3 +- .../models/vipllava/modeling_vipllava.py | 3 +- .../modeling_vision_encoder_decoder.py | 3 +- .../modeling_vision_text_dual_encoder.py | 3 +- src/transformers/models/vit/modeling_vit.py | 3 +- .../models/vit_mae/modeling_vit_mae.py | 3 +- .../models/vit_msn/modeling_vit_msn.py | 3 +- .../modeling_vitpose_backbone.py | 3 +- .../models/vivit/modeling_vivit.py | 3 +- .../models/vjepa2/modeling_vjepa2.py | 3 +- .../models/wav2vec2/modeling_wav2vec2.py | 3 +- .../models/wavlm/modeling_wavlm.py | 3 +- .../models/wavlm/modular_wavlm.py | 2 +- .../models/whisper/modeling_whisper.py | 3 +- .../models/yolos/modeling_yolos.py | 3 +- .../models/zamba/modeling_zamba.py | 26 +- .../models/zamba2/modeling_zamba2.py | 3 +- .../models/zamba2/modular_zamba2.py | 3 +- src/transformers/utils/args_doc.py | 7 +- tests/generation/test_utils.py | 24 +- tests/models/blip_2/test_modeling_blip_2.py | 9 +- .../test_modeling_instructblipvideo.py | 7 +- tests/test_modeling_common.py | 51 ++- 185 files changed, 451 insertions(+), 776 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7153257a10..8e10c8eef5 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1961,11 +1961,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi supports_gradient_checkpointing = False _is_stateful = False - # Flash Attention 2 support - _supports_flash_attn_2 = False - - # Flash Attention 3 support - _supports_flash_attn_3 = False + # Flash Attention support + _supports_flash_attn = False # SDPA support _supports_sdpa = False @@ -2074,12 +2071,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "`PretrainedConfig`. To create a model from a pretrained model use " f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) - if not getattr(config, "_attn_implementation_autoset", False): - # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests - dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype() - config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False) self.config = config + # The `hasattr` here is used as some Transformers tests for some reason do not call + # PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + if hasattr(config, "_attn_implementation_internal") and not getattr( + config, "_attn_implementation_autoset", False + ): + self.set_attention_implementation(self.config._attn_implementation_internal) + # for initialization of the loss loss_type = self.__class__.__name__ if loss_type not in LOSS_MAPPING: @@ -2226,19 +2226,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. if config._attn_implementation_internal is not None: - # In this case, the config has been created with the attn_implementation set by the user, which we - # should respect. + # In this case, the config has been created with the attn_implementation set by the user, which we should respect. attn_implementation = config._attn_implementation_internal else: attn_implementation = None - config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) - if not getattr(config, "_attn_implementation_autoset", False): - config = cls._autoset_attn_implementation( - config, - check_device_map=False, - torch_dtype=torch_dtype, - ) if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called: logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") @@ -2260,81 +2252,65 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return model @classmethod - def _autoset_attn_implementation( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - check_device_map: bool = True, - ): + def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]: """ - Automatically checks and dispatches to a default attention implementation. In order of priority: - 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). - 2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) - 3. The default model's implementation otherwise (`LlamaAttention` for example) . + Checks that the requested attention implementation exists and tries to get the kernel from hub + if `attn_implementation` matches hf kernels pattern. """ - # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user. - # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). - # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) - requested_attn_implementation = None - if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: - if isinstance(config._attn_implementation, str) and re.match( - r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation - ): - if not is_kernels_available(): - raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") + if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation): + if not is_kernels_available(): + raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") - # Extract repo_id and kernel_name from the string - repo_id, kernel_name = config._attn_implementation.split(":") - kernel_name = kernel_name.strip() - repo_id = repo_id.strip() + # Extract repo_id and kernel_name from the string + repo_id, kernel_name = attn_implementation.split(":") + kernel_name = kernel_name.strip() + repo_id = repo_id.strip() - try: - kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register( - f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name) - ) - config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}" - except FileNotFoundError as e: - logger.warning( - f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." - ) - config._attn_implementation = "eager" - except AttributeError: - raise ValueError( - "the kernel function name or class specified in the attn_implementation argument is not valid. \ - Please check the documentation for the correct format, \ - and check that the kernel exports the class and the function correctly." - ) + try: + kernel = get_kernel(repo_id) + ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) + attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + except FileNotFoundError as e: + logger.warning( + f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." + ) + attn_implementation = None # try to dispatch SDPA and fallback eager if not available + except AttributeError: + raise ValueError( + "the kernel function name or class specified in the attn_implementation argument is not valid. \ + Please check the documentation for the correct format, \ + and check that the kernel exports the class and the function correctly." + ) + if ( + not isinstance(attn_implementation, dict) + and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys() + ): + message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases + if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False): + message += ( + ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' + ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + ) + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + raise ValueError(message + ".") - if ( - not isinstance(config._attn_implementation, dict) - and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() - ): - message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' - if cls._supports_flash_attn_3: - message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' - if cls._supports_flash_attn_2: - message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' - if cls._supports_sdpa: - message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' - if cls._supports_flex_attn: - message += ( - ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' - ) - raise ValueError(message + ".") + return attn_implementation - # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. - requested_attn_implementation = config._attn_implementation_internal + def set_attention_implementation(self, attn_implementation: Union[str, dict]): + """ + Checks and dispatches to the requested attention implementation. + """ + requested_attn_implementation = self._check_attn_implementation(attn_implementation) - # Composite models consisting of several PretrainedModels have to specify attention impl as a dict - # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it - # for all sub-models. - # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict. - # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)` - # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238 - for key in config.sub_configs.keys(): - sub_config = getattr(config, key) + # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where + # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models. + # See https://github.com/huggingface/transformers/pull/32238 + for key in self.config.sub_configs.keys(): + sub_config = getattr(self.config, key) curr_attn_implementation = ( requested_attn_implementation if not isinstance(requested_attn_implementation, dict) @@ -2349,50 +2325,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ): sub_config._attn_implementation_internal = curr_attn_implementation - if config._attn_implementation == "flash_attention_3": - cls._check_and_enable_flash_attn_3( - config, - torch_dtype=torch_dtype, - device_map=device_map, - hard_check_only=False, - check_device_map=check_device_map, - ) - elif config._attn_implementation == "flash_attention_2": - cls._check_and_enable_flash_attn_2( - config, - torch_dtype=torch_dtype, - device_map=device_map, - hard_check_only=False, - check_device_map=check_device_map, - ) - elif requested_attn_implementation == "flex_attention": - config = cls._check_and_enable_flex_attn(config, hard_check_only=True) - elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): - # flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. - config = cls._check_and_enable_sdpa( - config, - hard_check_only=requested_attn_implementation is not None, - ) - - if ( - torch.version.hip is not None - and config._attn_implementation == "sdpa" - and torch.cuda.device_count() > 1 - and version.parse(torch.__version__) < version.parse("2.4.1") - ): - logger.warning_once( - "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." - ) - torch.backends.cuda.enable_flash_sdp(False) + if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch(): + self.config._attn_implementation = "flash_attention_3" + if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch(): + self.config._attn_implementation = "flash_attention_2" + elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch(): + self.config._attn_implementation = "flex_attention" + elif ( + requested_attn_implementation in [None, "sdpa"] + and not is_torch_xla_available() + and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None) + ): + self.config._attn_implementation = "sdpa" elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys(): - config._attn_implementation = requested_attn_implementation + self.config._attn_implementation = requested_attn_implementation elif isinstance(requested_attn_implementation, dict): - config._attn_implementation = None + self.config._attn_implementation = requested_attn_implementation.get("", None) else: - config._attn_implementation = "eager" + self.config._attn_implementation = "eager" - config._attn_implementation_autoset = True - return config + self.config._attn_implementation_autoset = True @classmethod def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: @@ -2466,24 +2418,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Otherwise, can't generate return False - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - check_device_map: bool = True, - hard_check_only: bool = False, - ) -> PretrainedConfig: + def _flash_attn_2_can_dispatch(self) -> bool: """ Checks the availability of Flash Attention 2 and compatibility with the current model. If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. """ - if not cls._supports_flash_attn_2: + # Config always has `torch_dtype` but we need the next line for `no_super_init()` tests + torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype() + device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None + + # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases + if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)): raise ValueError( - f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" - f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" ) @@ -2491,39 +2440,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." - if importlib.util.find_spec("flash_attn") is None: - # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit. - if is_torch_npu_available(): - if not hard_check_only: - config._attn_implementation = "flash_attention_2" - - logger.info("Detect using FlashAttention2 on Ascend NPU.") - return config - else: - raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") - - flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) - if torch.version.cuda: - if flash_attention_version < version.parse("2.1.0"): - raise ImportError( - f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" - ) - elif not torch.cuda.is_available(): - raise ValueError( - f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device." - ) - else: - raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") - elif torch.version.hip: - if flash_attention_version < version.parse("2.0.4"): - raise ImportError( - f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" - ) - else: - raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") - - _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi + if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available(): + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + else: + # Check FA2 installed version compatibility + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if torch.version.cuda: + if flash_attention_version < version.parse("2.1.0"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" + ) + elif not torch.cuda.is_available(): + raise ValueError( + f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device." + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + elif torch.version.hip: + if flash_attention_version < version.parse("2.0.4"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + _is_bettertransformer = getattr(self, "use_bettertransformer", False) if _is_bettertransformer: raise ValueError( "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" @@ -2536,13 +2478,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: logger.warning_once( "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but" - f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' ) # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, # or the model may be initialized under the context manager `with torch.device("cuda"):`. - if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: + if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: if torch.cuda.is_available(): logger.warning_once( "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" @@ -2560,8 +2502,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "or initialising the model on CPU and then moving it to GPU." ) elif ( - check_device_map - and device_map is not None + device_map is not None and isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()) ): @@ -2569,28 +2510,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) - if not hard_check_only: - config._attn_implementation = "flash_attention_2" - return config - @classmethod - def _check_and_enable_flash_attn_3( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - check_device_map: bool = True, - hard_check_only: bool = False, - ) -> PretrainedConfig: + # If no error raise by this point, we can return `True` + return True + + def _flash_attn_3_can_dispatch(self) -> bool: """ Checks the availability of Flash Attention 3 and compatibility with the current model. If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. """ - if not cls._supports_flash_attn_3: + # Config always has `torch_dtype` but we need the next line for `no_super_init()` tests + torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype() + device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None + + if not self._supports_flash_attn: raise ValueError( - f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" - f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" ) @@ -2620,22 +2557,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: logger.warning_once( "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but" - f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' ) - if getattr(config, "alibi", False) or getattr(config, "use_alibi", False): + if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False): raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.") # Check for attention dropout, which is incompatible with FA3 - if hasattr(config, "attention_dropout") and config.attention_dropout > 0: + if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0: raise ValueError( - f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3." + f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3." ) # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, # or the model may be initialized under the context manager `with torch.device("cuda"):`. - if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: + if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: if torch.cuda.is_available(): logger.warning_once( "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" @@ -2648,8 +2585,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "or initialising the model on CPU and then moving it to GPU." ) elif ( - check_device_map - and device_map is not None + device_map is not None and isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()) ): @@ -2657,21 +2593,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to " "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." ) - if not hard_check_only: - config._attn_implementation = "flash_attention_3" - return config + return True - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool: """ Checks the availability of SDPA for a given model. If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module. """ if hard_check_only: - if not cls._supports_sdpa: + if not self._supports_sdpa: raise ValueError( - f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' ) @@ -2680,45 +2613,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." ) - if not is_torch_sdpa_available() or not cls._supports_sdpa: - return config + if ( + torch.version.hip is not None + and torch.cuda.device_count() > 1 + and version.parse(torch.__version__) < version.parse("2.4.1") + ): + logger.warning_once( + "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." + ) + torch.backends.cuda.enable_flash_sdp(False) - _is_bettertransformer = getattr(cls, "use_bettertransformer", False) - if _is_bettertransformer: - return config + # This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported + _is_bettertransformer = getattr(self, "use_bettertransformer", False) + if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer: + return False - if not hard_check_only: - config._attn_implementation = "sdpa" - return config + return True - @classmethod - def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + def _flex_attn_can_dispatch(self) -> bool: """ Checks the availability of Flex Attention for a given model. If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module. """ - if hard_check_only: - if not cls._supports_flex_attn: - raise ValueError( - f"{cls.__name__} does not support an attention implementation through torch's flex_attention." - " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809." - " If you believe this error is a bug, please open an issue in Transformers GitHub repository" - ' and load your model with the argument `attn_implementation="eager"` meanwhile.' - ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' - ) - if not is_torch_flex_attn_available(): - raise ImportError( - "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." - ) + if not self._supports_flex_attn: + raise ValueError( + f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809." + " If you believe this error is a bug, please open an issue in Transformers GitHub repository" + ' and load your model with the argument `attn_implementation="eager"` meanwhile.' + ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_flex_attn_available(): + raise ImportError( + "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." + ) - if not is_torch_flex_attn_available() or not cls._supports_flex_attn: - return config - - if not hard_check_only: - config._attn_implementation = "flex_attention" - - return config + # If no error raise by this point, we can return `True` + return True def enable_input_require_grads(self): """ @@ -4803,13 +4735,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. - if not getattr(config, "_attn_implementation_autoset", False): - config = cls._autoset_attn_implementation( - config, - torch_dtype=torch_dtype, - device_map=device_map, - ) - with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index d9434e610b..c0d9f0990c 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -444,8 +444,7 @@ class Aimv2PreTrainedModel(PreTrainedModel): "Aimv2TextEmbeddings", ] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True def _init_weights(self, module): diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index 25156fbd1a..703c42e308 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -441,8 +441,7 @@ class Aimv2PreTrainedModel(PreTrainedModel): "Aimv2TextEmbeddings", ] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True def _init_weights(self, module): diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index a7902c75ec..8e1b1b168b 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -313,8 +313,7 @@ class ArceePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ArceeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 386346ae2a..a42d04717e 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -629,7 +629,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = True _supports_cache_class = True _supports_attention_backend = True @@ -661,8 +661,7 @@ class AriaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["AriaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index aeebd1dd09..fa371db78c 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1284,7 +1284,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = True _supports_cache_class = True _supports_attention_backend = True diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index a6baf63814..f37a5151fe 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -375,8 +375,7 @@ class ASTPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 599bc63b11..ccd8d3a56a 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -94,8 +94,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = False _supports_static_cache = False diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 1210de68df..0113ce6b8c 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1039,8 +1039,7 @@ class BambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index a80796cf6e..9c61066479 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -810,8 +810,7 @@ class BambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 362f9d2ba7..a4201faca5 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -355,8 +355,7 @@ class BarkBlock(GradientCheckpointingLayer): class BarkPreTrainedModel(PreTrainedModel): config_class = BarkConfig supports_gradient_checkpointing = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True def _init_weights(self, module): """Initialize the weights.""" @@ -1684,42 +1683,6 @@ class BarkModel(BarkPreTrainedModel): return audio - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - hard_check_only: bool = False, - check_device_map: bool = False, - ): - """ - `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model - sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention - if necessary. - - If you don't know about Flash Attention, check out the official repository of flash attention: - https://github.com/Dao-AILab/flash-attention - - For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this - specific section of the documentation to learn more about it: - https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models - - The method checks if the current setup is compatible with Flash Attention as it requires the model to be in - half precision and not ran on CPU. - - If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model - can initialize the correct attention module - """ - config = super()._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - - config.semantic_config._attn_implementation = config._attn_implementation - config.coarse_acoustics_config._attn_implementation = config._attn_implementation - config.fine_acoustics_config._attn_implementation = config._attn_implementation - return config - __all__ = [ "BarkFineModel", diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7d768a7348..f2729321f0 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -493,8 +493,7 @@ class BartPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a9fa3584d4..390b57a417 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -346,8 +346,7 @@ class BioGptPreTrainedModel(PreTrainedModel): config_class = BioGptConfig base_model_prefix = "biogpt" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 3fc8205f75..24d1c77fb6 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -171,8 +171,7 @@ class BioGptPreTrainedModel(PreTrainedModel): config_class = BioGptConfig base_model_prefix = "biogpt" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index ff3a17998f..c2e84df256 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -308,8 +308,7 @@ class BitNetPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BitNetDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 659b856a77..2af2d99448 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -463,8 +463,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index f82a0d3222..f81718d68c 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -451,8 +451,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index f141246701..21c533dd4f 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -409,8 +409,7 @@ class Blip2PreTrainedModel(PreTrainedModel): base_model_prefix = "blip" supports_gradient_checkpointing = True _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -1048,7 +1047,7 @@ class Blip2TextEmbeddings(nn.Module): ) class Blip2QFormerModel(Blip2PreTrainedModel): _supports_attention_backend = False # adds position on attn weights before last matmul - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False @@ -1238,7 +1237,7 @@ class Blip2Model(Blip2PreTrainedModel): config_class = Blip2Config main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] - _supports_flash_attn_2 = False # because self.qformer does not support FA2 + _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1623,7 +1622,7 @@ class Blip2Model(Blip2PreTrainedModel): class Blip2TextModelWithProjection(Blip2PreTrainedModel): supports_gradient_checkpointing = False _keep_in_fp32_modules = ["query_tokens", "qformer"] - _supports_flash_attn_2 = False # because self.qformer does not support FA2 + _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1716,7 +1715,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): class Blip2VisionModelWithProjection(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] - _supports_flash_attn_2 = False # because self.qformer does not support FA2 + _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1836,7 +1835,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens", "qformer"] - _supports_flash_attn_2 = False # because self.qformer does not support FA2 + _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -2267,7 +2266,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] - _supports_flash_attn_2 = False # because self.qformer does not support FA2 + _supports_flash_attn = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 4f699914b9..0100f191d2 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -821,8 +821,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index e453c10339..eb697fcccf 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -428,8 +428,7 @@ class CLIPPreTrainedModel(PreTrainedModel): base_model_prefix = "clip" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index f457b4e8ac..6e43f257a3 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -341,8 +341,7 @@ class CoherePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index d2cb22132b..82ac1b2b61 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -318,8 +318,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Cohere2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 65bfa7d921..0a64bcc60e 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -41,8 +41,7 @@ class ColQwen2PreTrainedModel(PreTrainedModel): config_class = ColQwen2Config base_model_prefix = "model" _no_split_modules = [] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 08b79e247e..4f5ce4aa8a 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -225,8 +225,7 @@ class ColQwen2Processor(ColPaliProcessor): class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 7c18216eeb..f0f7eecbad 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -366,8 +366,7 @@ class CsmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CsmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index f6098993c3..266327b13e 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -129,8 +129,7 @@ class CsmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CsmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index bd3cc865ac..94dedbfb38 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -505,8 +505,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): base_model_prefix = "data2vec_audio" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 2de816e80e..314e08ed4e 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -139,8 +139,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): base_model_prefix = "data2vec_audio" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 84c72d9ac2..3bef3e3293 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -809,8 +809,7 @@ class DbrxPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 47862f30c7..836b9dc9de 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -455,8 +455,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index d686c8297b..6c00b64eef 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -494,8 +494,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 90b678466c..52ca6bdac9 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -442,8 +442,7 @@ class DeiTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeiTLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 38bc0715fc..f801a7f603 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -64,8 +64,7 @@ class DiaPreTrainedModel(PreTrainedModel): config_class = DiaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 18512856dd..5dfa78ce36 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -59,8 +59,7 @@ class DiaPreTrainedModel(PreTrainedModel): config_class = DiaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 1e19c226e2..26de946643 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -530,8 +530,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DiffLlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False _supports_cache_class = True diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6b98e3fa8d..e266727bf9 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -486,8 +486,7 @@ class Dinov2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dinov2Layer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index 7d37f00daa..69236ab67f 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -504,8 +504,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dinov2WithRegistersLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index a047a676be..feb8d6d8bc 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -575,8 +575,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): load_tf_weights = None base_model_prefix = "distilbert" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 8aaf546481..969150c7c7 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -491,8 +491,7 @@ class DogePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DogeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False - _supports_flash_attn_3 = False + _supports_flash_attn = False _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 208989aff8..a3d1b4f9bf 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -563,8 +563,7 @@ class DogeDecoderLayer(GradientCheckpointingLayer): class DogePreTrainedModel(LlamaPreTrainedModel): - _supports_flash_attn_3 = False - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_static_cache = False _can_record_outputs = { "router_logits": OutputRecorder(DogeCDMoE, index=1), diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 0d2409f8e8..4ac1420e2d 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -414,8 +414,7 @@ class Dots1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dots1DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 4d0a072dcc..8614a4de6e 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -809,8 +809,7 @@ class DPTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index c3dd0c1870..cdf5eee993 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -926,8 +926,7 @@ class Emu3VQVAE(PreTrainedModel): base_model_prefix = "emuvideovq" main_input_name = "pixel_values" _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True _no_split_modules = [ @@ -1096,8 +1095,7 @@ class Emu3PreTrainedModel(PreTrainedModel): "Emu3DecoderLayer", ] _skip_keys_device_placement = ["past_key_values", "causal_mask"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 6df9469498..58689de09b 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -678,8 +678,7 @@ class Emu3VQVAE(PreTrainedModel): base_model_prefix = "emuvideovq" main_input_name = "pixel_values" _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True _no_split_modules = [ diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 7bf808db22..48ec65a5e2 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -78,8 +78,7 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index 025d28350f..29c33cec05 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -1001,8 +1001,7 @@ class EomtPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = ["EomtLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index aeb562fc26..0a1b7dfd95 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -373,8 +373,7 @@ class EomtPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = ["EomtLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index f448be910f..2cb697b08e 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -738,8 +738,7 @@ class EsmPreTrainedModel(PreTrainedModel): base_model_prefix = "esm" supports_gradient_checkpointing = True _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead def _init_weights(self, module): diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 81f4fcc4fe..8c74afdc7c 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -1990,7 +1990,7 @@ class EsmFoldingTrunk(nn.Module): ) class EsmForProteinFolding(EsmPreTrainedModel): _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] - _supports_flash_attn_2 = False + _supports_flash_attn = False def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ebcdf42086..9e31eb1c90 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -640,8 +640,7 @@ class FalconPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 20067056ba..82173de99f 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1150,8 +1150,7 @@ class FalconH1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconH1DecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 13146e7bd1..89c6abc411 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -928,8 +928,7 @@ class FalconH1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconH1DecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 589620ff80..d2838fa8e0 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -38,8 +38,7 @@ class FuyuPreTrainedModel(PreTrainedModel): base_model_prefix = "fuyu" supports_gradient_checkpointing = True _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f79f62977d..e5b06aaa64 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -310,8 +310,7 @@ class GemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 14fff44d8a..fceb1cf9d0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -340,8 +340,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Gemma2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index d65aed200c..3f50d5f17b 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -430,8 +430,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 1e3641ddc6..3eb10a1f22 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1486,8 +1486,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Gemma3nTextDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 28e6295292..ccb4cb583a 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -327,8 +327,7 @@ class GlmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GlmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index ceb9f2d4dd..1a1b5abe57 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -331,8 +331,7 @@ class Glm4PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Glm4DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 3349d58402..1031969679 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -404,8 +404,7 @@ class Glm4vPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 2aec3cf284..c85d2c962f 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -281,8 +281,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 00d21c7d60..a98bf0c235 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -559,8 +559,7 @@ class GPT2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_attention_backend = True _supports_cache_class = True diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 711c29a029..2d2418d1f7 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -655,8 +655,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 9d606297e8..896b2123c6 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -476,8 +476,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 599e7ee76f..7359d7b46e 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -360,8 +360,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 0aa03b559b..5c8e9e81c2 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -471,8 +471,7 @@ class GPTJPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index e94ef030f0..f2cf41c249 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -305,8 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 2b0efc519e..447f222d85 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -283,8 +283,7 @@ class GraniteSpeechCTCEncoder(nn.Module): class GraniteSpeechPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechConfig _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 065dfca74b..824c7ccd8b 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -589,8 +589,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GraniteMoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 8139103d21..fdfdae6112 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1162,8 +1162,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GraniteMoeHybridDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 475e97dc84..527ee691d4 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -507,8 +507,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GraniteMoeSharedDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 7cab557ac9..c238fa200f 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -312,8 +312,7 @@ class HeliumPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["HeliumDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index d828e0a49d..6f2d9cf3ae 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -684,8 +684,7 @@ class HubertPreTrainedModel(PreTrainedModel): base_model_prefix = "hubert" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index 5e0102c1d7..3e12c14e4c 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -129,8 +129,7 @@ class HubertPreTrainedModel(PreTrainedModel): base_model_prefix = "hubert" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index dd05709e08..cf5fad7bb5 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -879,8 +879,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] _supports_sdpa = True _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs _supports_attention_backend = True diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 06fccf9614..e596bce9de 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -456,8 +456,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True @@ -495,8 +494,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): class Idefics2VisionTransformer(Idefics2PreTrainedModel): config_class = Idefics2VisionConfig _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True def __init__(self, config: Idefics2VisionConfig): diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index c018963943..2461f4b956 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -473,8 +473,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True @@ -506,8 +505,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): class Idefics3VisionTransformer(Idefics3PreTrainedModel): config_class = Idefics3VisionConfig _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True def __init__(self, config: Idefics3VisionConfig): diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 6c2df8d2bb..2c16928f0a 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -151,8 +151,7 @@ class IJepaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index e3d0d79081..231753cea5 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -94,8 +94,7 @@ class IJepaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index de541f3716..ad94340b50 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -336,8 +336,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): base_model_prefix = "blip" supports_gradient_checkpointing = True _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True @@ -966,7 +965,7 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel): """ _supports_attention_backend = False # adds position on attn weights before last matmul - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index c0051d502c..68ed042a3d 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -823,8 +823,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): base_model_prefix = "blip" supports_gradient_checkpointing = True _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True @@ -927,7 +926,7 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel): """ _supports_attention_backend = False # adds position on attn weights before last matmul - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 68074964c4..687a6dd1e3 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -179,8 +179,7 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["InternVLVisionLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True @@ -522,8 +521,7 @@ class InternVLPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 1c29907bb1..0e52b89872 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -141,8 +141,7 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["InternVLVisionLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index ce2f13eb0e..4b15f42ead 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1069,8 +1069,7 @@ class JambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 9c31d8530d..bb8dec3b46 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -60,8 +60,7 @@ class JanusPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 37be985a42..df94df47b9 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -387,8 +387,7 @@ class JanusPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index f9bd85898c..7009c2d7d5 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -826,8 +826,7 @@ class JetMoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = ["JetMoeBlock"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index d9bfb7db13..da44dae8e3 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1151,8 +1151,7 @@ class Kosmos2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"] _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index a067580cf0..41b5800b23 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -117,8 +117,7 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True main_input_name = "input_ids" diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 4931a3a46e..1c9e026624 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -539,8 +539,7 @@ class Lfm2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Lfm2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index 37d6502e5f..5b863a7a93 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -423,8 +423,7 @@ class LightGluePreTrainedModel(PreTrainedModel): base_model_prefix = "lightglue" main_input_name = "pixel_values" supports_gradient_checkpointing = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module: nn.Module) -> None: diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index beacacab02..544cad5c79 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -508,8 +508,7 @@ class LightGluePreTrainedModel(PreTrainedModel): base_model_prefix = "lightglue" main_input_name = "pixel_values" supports_gradient_checkpointing = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module: nn.Module) -> None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index cded4e58b3..3a078129b0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -311,8 +311,7 @@ class LlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 653006d631..b3781f65c0 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -433,7 +433,7 @@ class Llama4PreTrainedModel(PreTrainedModel): config_class = Llama4Config supports_gradient_checkpointing = True _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index b7b67a8a77..b39360bb58 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -118,8 +118,7 @@ class LlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 1c33fb8a04..249b85a559 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -229,8 +229,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index cf6435e1c7..8a218c5f4a 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -170,8 +170,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 6552f840a7..65c26c78de 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -283,8 +283,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index df2f6cdaa9..1f33da304f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -529,8 +529,7 @@ class M2M100PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index c4393a3948..3782d4bc3d 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -467,8 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 861eeaf68e..6f3ca84a24 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -497,8 +497,7 @@ class MBartPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 23208f3006..3ca8fe2b39 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1373,8 +1373,7 @@ class MimiPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MimiDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 2157315238..a5ad9c68ad 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -585,8 +585,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MiniMaxDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True # Note: only supports MiniMaxCache diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 8724a0b904..cc03b14c55 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -256,8 +256,7 @@ class MistralPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index d319e973ac..8abcab6d36 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -183,8 +183,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 32f8f9a84b..ac64ace046 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -385,8 +385,7 @@ class MixtralPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index 4873217410..12fd0c6830 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -508,8 +508,7 @@ class MLCDPreTrainedModel(PreTrainedModel): config_class = MLCDVisionConfig base_model_prefix = "mlcd" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 58fd45c9b8..a640ed0b59 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -442,8 +442,7 @@ class MLCDPreTrainedModel(PreTrainedModel): config_class = MLCDVisionConfig base_model_prefix = "mlcd" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0f68c2d03d..0c418e70a9 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -852,8 +852,7 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_quantized_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index f4fd7cf37b..e6d6c3e712 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -560,8 +560,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False @@ -612,36 +611,25 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - @classmethod - def _autoset_attn_implementation( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - check_device_map: bool = True, - ): + def set_attention_implementation(self, attn_implementation: Union[str, dict]): + """ + Checks and dispatches to hhe requested attention implementation. + """ # If the user didn't specify anything, try to use flash_attention_2 if available. # Otherwise we fall back to the default SDPA -> Eager from the super() method. # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. - if config._attn_implementation_internal is None: - config._attn_implementation_internal = "flash_attention_2" - try: - return cls._check_and_enable_flash_attn_2( - config, - torch_dtype=torch.float16, - device_map=device_map, - hard_check_only=False, - check_device_map=check_device_map, - ) - except (ValueError, ImportError): - config._attn_implementation_internal = None - return super()._autoset_attn_implementation( - config, - torch_dtype=torch_dtype, - device_map=device_map, - check_device_map=check_device_map, - ) + + requested_attn_implementation = self._check_attn_implementation(attn_implementation) + try: + attn_implementation = ( + "flash_attention_2" + if requested_attn_implementation is None and self._flash_attn_2_can_dispatch() + else attn_implementation + ) + except (ValueError, ImportError): + pass + return super().set_attention_implementation(attn_implementation=attn_implementation) def _maybe_set_compile(self): if self.config.reference_compile is False: diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 94e45fcc5a..32e694d7d5 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -760,8 +760,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False @@ -812,36 +811,25 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - @classmethod - def _autoset_attn_implementation( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - check_device_map: bool = True, - ): + def set_attention_implementation(self, attn_implementation: Union[str, dict]): + """ + Checks and dispatches to hhe requested attention implementation. + """ # If the user didn't specify anything, try to use flash_attention_2 if available. # Otherwise we fall back to the default SDPA -> Eager from the super() method. # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. - if config._attn_implementation_internal is None: - config._attn_implementation_internal = "flash_attention_2" - try: - return cls._check_and_enable_flash_attn_2( - config, - torch_dtype=torch.float16, - device_map=device_map, - hard_check_only=False, - check_device_map=check_device_map, - ) - except (ValueError, ImportError): - config._attn_implementation_internal = None - return super()._autoset_attn_implementation( - config, - torch_dtype=torch_dtype, - device_map=device_map, - check_device_map=check_device_map, - ) + + requested_attn_implementation = self._check_attn_implementation(attn_implementation) + try: + attn_implementation = ( + "flash_attention_2" + if requested_attn_implementation is None and self._flash_attn_2_can_dispatch() + else attn_implementation + ) + except (ValueError, ImportError): + pass + return super().set_attention_implementation(attn_implementation=attn_implementation) def _maybe_set_compile(self): if self.config.reference_compile is False: diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 9b55fa8c96..33d0a75f37 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -459,8 +459,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 7180d35e8e..5e56fee5e0 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -494,8 +494,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 45bea58cdf..3b09eba5e0 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -805,8 +805,7 @@ class MoshiPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True main_input_name = "input_ids" @@ -1632,8 +1631,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): config_class = MoshiConfig main_input_name = "input_ids" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__(self, config: MoshiConfig): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 31161cb54a..2062a2333c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -434,8 +434,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -1346,8 +1345,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 6fd14cb842..7ad2b8c2d5 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -400,8 +400,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -1274,8 +1273,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenMelodyConfig main_input_name = "input_ids" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 30b3faa144..fe75aa212e 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -586,8 +586,7 @@ class NemotronPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["NemotronDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 48f184c078..1e0dddeac9 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -853,7 +853,7 @@ class NllbMoePreTrainedModel(PreTrainedModel): # TODO: If anyone is up to it to make sure tests pass etc # Flash attention has problems due to not preparing masks the same way as eager/sdpa # SDPA has more flaky logits which requires more time to look into tests - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 85a837ad38..8459ba57a6 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -290,8 +290,7 @@ class OlmoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 8f0aa68d8b..8eb966f313 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -295,8 +295,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Olmo2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 0fa4d1961a..420d239b2d 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -703,8 +703,7 @@ class OlmoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OlmoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index de52e72534..2f5a07f79a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -309,8 +309,7 @@ class OPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 9ee315b8be..e3c17bc18e 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -116,8 +116,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index d8f8a511bc..9d0f3a9dd8 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -462,8 +462,7 @@ class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 500150bba4..29dbbdf328 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -757,8 +757,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True # Flaky logits _supports_sdpa = False _supports_flex_attn = True diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 4e00bed6d6..e0f09b574e 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -92,8 +92,7 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 5e62375845..af7d1db9e2 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -392,8 +392,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 1782d80b4b..527574f613 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -295,8 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 6fd7430508..865a3973ad 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -287,8 +287,7 @@ class Phi3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index f3ad3bafb8..a464e67e68 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -370,8 +370,7 @@ class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True @@ -996,8 +995,7 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): config_class = Phi4MultimodalAudioConfig supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -1591,8 +1589,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 4ccba03cce..d8c7c03b76 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -540,8 +540,7 @@ class Phi4MultimodalVisionPreTrainedModel(SiglipPreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True @@ -1121,8 +1120,7 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): config_class = Phi4MultimodalAudioConfig supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index e8f9e10600..328d749cd6 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -887,8 +887,7 @@ class PhimoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PhimoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 5f9e135326..616f5810b2 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -405,13 +405,11 @@ class PixtralPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_attention_backend = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _no_split_modules = ["PixtralAttentionLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3b071a1fe3..95835fd977 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -77,8 +77,7 @@ class PLBartPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 22200c0bbf..5202e61de8 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -62,8 +62,7 @@ class PLBartPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index d3882e22e6..2857895696 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -259,8 +259,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 2ac680b7fe..0136daf174 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -85,8 +85,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniVisionBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 47283174b9..b06150ccfa 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -323,8 +323,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 31cda06324..a689b3ecf3 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -267,8 +267,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2AudioAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index c3cce2f9b6..7ec40f2581 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -744,8 +744,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a5dd8099bd..a43c4d6a72 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -656,8 +656,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index a0c6d30239..2ebdbc9756 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -285,8 +285,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 75d0ea8742..1d0f327523 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -408,8 +408,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen3MoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 299c90e2cf..1af81cf959 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -234,8 +234,7 @@ class RetrievAugLMOutput(ModelOutput): class RagPreTrainedModel(PreTrainedModel): config_class = RagConfig base_model_prefix = "rag" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True @classmethod diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 4b82c490cc..02af226c5b 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -509,7 +509,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["RecurrentGemmaDecoderLayer"] _skip_keys_device_placement = ["cache"] - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False # we can't compare with eager for now _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 863462eed6..b3f4388cb5 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -518,8 +518,7 @@ class SEWPreTrainedModel(PreTrainedModel): base_model_prefix = "sew" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False # needs a proper look into the mask creation diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 2eba4010d6..b093987548 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -263,8 +263,7 @@ class SEWPreTrainedModel(PreTrainedModel): base_model_prefix = "sew" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False # needs a proper look into the mask creation diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index d8fc3ed628..dfc252c473 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -479,8 +479,7 @@ class SiglipPreTrainedModel(PreTrainedModel): "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead", ] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 25543b901a..2ff20c8b23 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -712,8 +712,7 @@ class Siglip2PreTrainedModel(PreTrainedModel): "Siglip2EncoderLayer", "Siglip2MultiheadAttentionPoolingHead", ] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index ed0233940d..da584a63fc 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -289,8 +289,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SmolLM3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index befec29d9d..2fd0776edf 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -54,8 +54,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SmolVLMVisionAttention", "SmolVLMDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True @@ -373,8 +372,7 @@ class SmolVLMEncoder(nn.Module): class SmolVLMVisionTransformer(SmolVLMPreTrainedModel): config_class = SmolVLMVisionConfig _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True def __init__(self, config: SmolVLMVisionConfig): diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 703584d43c..cbb1915022 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -66,8 +66,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "inputs" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 29debaeaac..3fcdf1af3b 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -528,7 +528,7 @@ class Speech2TextPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True # TODO: tests would need a rewrite to check for correct implementation # Current tests always assume certain inputs to be passed - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 0aef9d3ab7..66551b53cb 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -618,8 +618,7 @@ class StableLmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_cache_class = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index f4f805c016..c3f10cea11 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -293,8 +293,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 6e31f18132..f37995ae32 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -581,8 +581,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["T5GemmaBlock"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 778a0485b4..0732ec7a3c 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -635,7 +635,7 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True # TODO: tests would need a rewrite to check for correct implementation # Current tests always assume certain inputs to be passed - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index e995b8c1dd..d1c2c0e00a 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -786,8 +786,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): base_model_prefix = "unispeech" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index ae7a37e093..0f4a98f9be 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -142,8 +142,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): base_model_prefix = "unispeech" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 86d4de0fe5..21bd79613d 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -791,8 +791,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): base_model_prefix = "unispeech_sat" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index a78b5bf2b0..087f83f958 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -154,8 +154,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): base_model_prefix = "unispeech_sat" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 4fa05080ad..ea94b07bfa 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -132,8 +132,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 74293b3971..39784ca889 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -471,8 +471,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 2826e7449e..94ced611b0 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -119,8 +119,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index fd92b19e5e..6bda9bf122 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -68,8 +68,7 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index d8fa506d7c..f0d806c311 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -46,8 +46,7 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor: class VisionTextDualEncoderModel(PreTrainedModel): config_class = VisionTextDualEncoderConfig base_model_prefix = "vision_text_dual_encoder" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index f898b81382..8e38f83cac 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -440,8 +440,7 @@ class ViTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ViTEmbeddings", "ViTLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 8f03768639..ddd582eca2 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -605,8 +605,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index c8b1b6f6cf..0c3da4fffa 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -444,8 +444,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index dad8dfe9c4..af2ca9825e 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -405,8 +405,7 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 8a2f317760..ca7c3046a5 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -448,8 +448,7 @@ class VivitPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index efbf452d1f..0226d7d9ad 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -989,8 +989,7 @@ class VJEPA2PreTrainedModel(PreTrainedModel): "VJEPA2PredictorEmbeddings", ] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a06b74bfd6..88ca7c2540 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1035,8 +1035,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): base_model_prefix = "wav2vec2" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index ba5fba91d4..9f571c5dba 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -598,8 +598,7 @@ class WavLMPreTrainedModel(PreTrainedModel): base_model_prefix = "wavlm" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = False - _supports_flash_attn_3 = True + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index aac25ff262..7666ed0561 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -508,7 +508,7 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): base_model_prefix = "wavlm" main_input_name = "input_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_flex_attn = False diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f8473cda9f..9442412b02 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -560,8 +560,7 @@ class WhisperPreTrainedModel(PreTrainedModel): main_input_name = "input_features" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index bb9f034e81..96ae1fe1d9 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -518,8 +518,7 @@ class YolosPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 47c43bddc9..06d689ae68 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -786,7 +786,7 @@ class ZambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False + _supports_flash_attn = False _supports_sdpa = False _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache _is_stateful = True @@ -823,30 +823,6 @@ class ZambaPreTrainedModel(PreTrainedModel): module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1)) module.D.data.fill_(1.0) - @classmethod - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, dict[str, int]]] = None, - hard_check_only: bool = False, - check_device_map: bool = False, - ): - """ - Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba models. - Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba v1. - """ - config = super()._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "flash_attention_2": - config._attn_implementation = "eager" - - return config - @auto_docstring class ZambaModel(ZambaPreTrainedModel): diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5731d56643..ce7555058a 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1176,8 +1176,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index a980d35828..032a2dd5cb 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -899,8 +899,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache diff --git a/src/transformers/utils/args_doc.py b/src/transformers/utils/args_doc.py index 5028f4687a..f98bd89b34 100644 --- a/src/transformers/utils/args_doc.py +++ b/src/transformers/utils/args_doc.py @@ -927,11 +927,8 @@ class ClassAttrs: _skip_keys_device_placement = r""" A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library. """ - _supports_flash_attn_3 = r""" - Whether the model's attention implementation supports FlashAttention 3.0. - """ - _supports_flash_attn_2 = r""" - Whether the model's attention implementation supports FlashAttention 2.0. + _supports_flash_attn = r""" + Whether the model's attention implementation supports FlashAttention. """ _supports_sdpa = r""" Whether the model's attention implementation supports SDPA (Scaled Dot Product Attention). diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6a82319cf9..ea00871d27 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -29,7 +29,15 @@ import pytest from packaging import version from parameterized import parameterized -from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedModel, + is_torch_available, + logging, + pipeline, +) from transformers.testing_utils import ( CaptureLogger, is_flaky, @@ -2007,7 +2015,7 @@ class GenerationTesterMixin: max_new_tokens = 20 for dtype in (torch.float32, torch.float16): - model = model_class(config).to(torch_device).to(dtype).eval() + model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval() inputs_dict = { k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v for k, v in inputs_dict.items() @@ -2340,6 +2348,18 @@ class GenerationTesterMixin: set_config_for_less_flaky_test(config) model = model_class(config) + # If not all sub-models support flex, skip the test. We could potentially set not supported backbones + # to "eager" attention, leaving it for future updates on multimodality tests + sub_models_supporting_attn = [ + getattr(module, support_flag[attn_implementation]) + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ] + if not all(sub_models_supporting_attn) and len(sub_models_supporting_attn) > 0: + self.skipTest( + f"One of {model_class.__name__}'s backbones does not support `attn_implementation={attn_implementation}`" + ) + with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) del model diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 8fab490859..4d24fc6e70 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch BLIP-2 model.""" +import copy import inspect import tempfile import unittest @@ -184,9 +185,8 @@ class Blip2VisionModelTest(ModelTesterMixin, unittest.TestCase): self.assertTrue(x is None or isinstance(x, nn.Linear)) def test_forward_signature(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) signature = inspect.signature(model.forward) # signature.parameters is an OrderedDict => so arg_names order is deterministic @@ -987,9 +987,8 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi raise ValueError("The eager model should not have SDPA attention layers") def test_forward_signature(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) signature = inspect.signature(model.forward) # signature.parameters is an OrderedDict => so arg_names order is deterministic @@ -1077,7 +1076,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi for key in ["vision_config", "qformer_config", "text_config"]: setattr(configs_no_init, key, _config_zero_init(getattr(configs_no_init, key))) for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) + model = model_class(config=copy.deepcopy(configs_no_init)) for name, param in model.named_parameters(): if param.requires_grad: self.assertIn( diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 96737ed12c..6b5f4b6614 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -187,9 +187,9 @@ class InstructBlipVideoVisionModelTest(ModelTesterMixin, unittest.TestCase): self.assertTrue(x is None or isinstance(x, nn.Linear)) def test_forward_signature(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) signature = inspect.signature(model.forward) # signature.parameters is an OrderedDict => so arg_names order is deterministic @@ -541,9 +541,8 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( pass def test_forward_signature(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) signature = inspect.signature(model.forward) # signature.parameters is an OrderedDict => so arg_names order is deterministic diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2550af4227..bed63445d3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -699,7 +699,7 @@ class ModelTesterMixin: def test_from_pretrained_no_checkpoint(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) state_dict = model.state_dict() new_model = model_class.from_pretrained( @@ -714,7 +714,7 @@ class ModelTesterMixin: if model_class._keep_in_fp32_modules is None: self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined") - model = model_class(config) + model = model_class(copy.deepcopy(config)) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -730,7 +730,7 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None) if _keys_to_ignore_on_save is None: continue @@ -766,7 +766,7 @@ class ModelTesterMixin: continue config.gradient_checkpointing = True - model = model_class(config) + model = model_class(copy.deepcopy(config)) self.assertTrue(model.is_gradient_checkpointing) def test_gradient_checkpointing_enable_disable(self): @@ -777,7 +777,7 @@ class ModelTesterMixin: continue # at init model should have gradient checkpointing disabled - model = model_class(config) + model = model_class(copy.deepcopy(config)) self.assertFalse(model.is_gradient_checkpointing) # check enable works @@ -810,7 +810,7 @@ class ModelTesterMixin: continue # at init model should have gradient checkpointing disabled - model = model_class(config) + model = model_class(copy.deepcopy(config)) self.assertFalse(model.is_gradient_checkpointing) # check enable works @@ -871,7 +871,7 @@ class ModelTesterMixin: # First, initialize the model from config -> this ensure everything is correctly initialized, even if # _init_weights() does not take all weights into account correctly - model_from_config = model_class(config) + model_from_config = model_class(copy.deepcopy(config)) # Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized # by _init_weights() model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={}) @@ -944,7 +944,7 @@ class ModelTesterMixin: base_class_copy._init_weights = _mock_init_weights base_class_copy.init_weights = _mock_all_init_weights - model = model_class(config) + model = model_class(copy.deepcopy(config)) state_dict = model.state_dict() def check_equal(loaded): @@ -969,7 +969,7 @@ class ModelTesterMixin: configs_no_init = _config_zero_init(config) for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) + model = model_class(config=copy.deepcopy(configs_no_init)) for name, param in model.named_parameters(): if param.requires_grad: data = torch.flatten(param.data) @@ -1000,7 +1000,7 @@ class ModelTesterMixin: self.assertLessEqual(max_diff, 1e-5) for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) model.to(torch_device) model.eval() with torch.no_grad(): @@ -1075,7 +1075,7 @@ class ModelTesterMixin: if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"): config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class) batched_input_prepared = self._prepare_for_class(batched_input, model_class) - model = model_class(config).to(torch_device).eval() + model = model_class(copy.deepcopy(config)).to(torch_device).eval() set_model_for_less_flaky_test(model) batch_size = self.model_tester.batch_size @@ -1932,7 +1932,7 @@ class ModelTesterMixin: def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): - model = model_class(config) + model = model_class(copy.deepcopy(config)) model.to(torch_device) model.eval() @@ -2061,16 +2061,15 @@ class ModelTesterMixin: ) = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: torch.manual_seed(0) - config = copy.deepcopy(original_config) - model = model_class(config) + model = model_class(copy.deepcopy(original_config)) model.to(torch_device) model.eval() hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0] torch.manual_seed(0) - config.chunk_size_feed_forward = 1 - model = model_class(config) + original_config.chunk_size_feed_forward = 1 + model = model_class(copy.deepcopy(original_config)) model.to(torch_device) model.eval() @@ -2445,7 +2444,7 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) self.assertIsInstance(model.get_input_embeddings(), nn.Embedding) new_input_embedding_layer = nn.Embedding(10, 10) @@ -2505,7 +2504,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: config.torchscript = True - model_not_tied = model_class(config) + model_not_tied = model_class(copy.deepcopy(config)) if model_not_tied.get_output_embeddings() is None: continue @@ -2582,7 +2581,7 @@ class ModelTesterMixin: config, _ = self.model_tester.prepare_config_and_inputs_for_common() config.get_text_config().tie_word_embeddings = True for model_class in self.all_model_classes: - model_tied = model_class(config) + model_tied = model_class(copy.deepcopy(config)) ptrs = collections.defaultdict(list) for name, tensor in model_tied.state_dict().items(): @@ -2707,7 +2706,7 @@ class ModelTesterMixin: recursive_check(tuple_output, dict_output) for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) model.to(torch_device) model.eval() @@ -3033,7 +3032,7 @@ class ModelTesterMixin: continue inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config).eval() + model = model_class(copy.deepcopy(config)).eval() model = model.to(torch_device) torch.manual_seed(0) base_output = model(**inputs_dict_class) @@ -3077,7 +3076,7 @@ class ModelTesterMixin: continue inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config).eval() + model = model_class(copy.deepcopy(config)).eval() model = model.to(torch_device) torch.manual_seed(0) base_output = model(**inputs_dict_class) @@ -3115,7 +3114,7 @@ class ModelTesterMixin: continue inputs_dict_class = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config).eval() + model = model_class(copy.deepcopy(config)).eval() model = model.to(torch_device) torch.manual_seed(0) @@ -3470,7 +3469,7 @@ class ModelTesterMixin: config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - model = model_class(config) + model = model_class(copy.deepcopy(config)) num_params = model.num_parameters() assert num_params < 1000000, ( f"{model_class} is too big for the common tests ({num_params})! It should have 1M max." @@ -3628,7 +3627,7 @@ class ModelTesterMixin: # set eager as it will be the one supported in all models # we just need to test if passing 'attn_implementation' as a dict fails or not - attn_implementation_per_subconfig = {} + attn_implementation_per_subconfig = {"": "eager"} for key in config.sub_configs.keys(): attn_implementation_per_subconfig[key] = "eager" @@ -4717,7 +4716,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: # If it does not raise here, the test passes with torch.device("meta"): - _ = model_class(config) + _ = model_class(copy.deepcopy(config)) @require_torch_accelerator def test_can_load_with_device_context_manager(self):