From 0ea1151222b0ba4bf8e509e5e7ae73b57359d296 Mon Sep 17 00:00:00 2001 From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:13:25 +0200 Subject: [PATCH] Llama Kernel integration (#37092) * initial commit * style * update * change approach attention * clean up * fix import * update * update * fix style * change method * attention * add mlp back * change name * update name * fix copies * fix config * fix --- src/transformers/integrations/hub_kernels.py | 15 +++++++- src/transformers/modeling_utils.py | 38 ++++++++++++++++++- src/transformers/models/aria/modeling_aria.py | 5 ++- .../models/bamba/modeling_bamba.py | 4 ++ .../models/cohere/modeling_cohere.py | 2 + .../models/cohere2/modeling_cohere2.py | 2 + .../deepseek_v3/modeling_deepseek_v3.py | 3 +- .../models/diffllama/modeling_diffllama.py | 4 +- src/transformers/models/emu3/modeling_emu3.py | 4 ++ .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 2 + .../models/gemma3/modeling_gemma3.py | 2 + src/transformers/models/glm/modeling_glm.py | 4 +- src/transformers/models/glm4/modeling_glm4.py | 3 ++ .../models/granite/modeling_granite.py | 4 ++ .../models/helium/modeling_helium.py | 4 +- .../models/llama/modeling_llama.py | 6 ++- .../models/mistral/modeling_mistral.py | 4 +- .../models/mistral3/modeling_mistral3.py | 2 + .../models/mixtral/modeling_mixtral.py | 2 + .../models/moonshine/modeling_moonshine.py | 1 - src/transformers/models/olmo/modeling_olmo.py | 3 +- .../models/olmo2/modeling_olmo2.py | 3 ++ src/transformers/models/phi3/modeling_phi3.py | 2 + .../modeling_phi4_multimodal.py | 2 + .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen3/modeling_qwen3.py | 4 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 2 + .../models/starcoder2/modeling_starcoder2.py | 1 - src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 +++ 31 files changed, 127 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index b2ec6b5371..ba41b2c0a8 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -31,7 +31,20 @@ try: repo_id="kernels-community/deformable-detr", layer_name="MultiScaleDeformableAttention", ) - } + }, + "RMSNorm": { + "cuda": LayerRepository( + repo_id="kernels-community/triton-layer-norm", + layer_name="LlamaRMSNorm", + revision="pure-layer-test", + ) + }, + "MLP": { + "cuda": LayerRepository( + repo_id="medmekk/triton-llama-mlp", + layer_name="TritonLlamaMLP", + ) + }, } register_kernel_mapping(_KERNEL_MAPPING) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 76526c360c..28b27cf114 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -102,6 +102,7 @@ from .utils import ( is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_kernels_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -157,6 +158,9 @@ if is_safetensors_available(): if is_deepspeed_available(): import deepspeed +if is_kernels_available(): + from kernels import get_kernel + logger = logging.get_logger(__name__) @@ -2024,6 +2028,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) + if isinstance(config._attn_implementation, str) and re.match( + r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation + ): + if not is_kernels_available(): + raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") + + # Extract repo_id and kernel_name from the string + repo_id, kernel_name = config._attn_implementation.split(":") + kernel_name = kernel_name.strip() + repo_id = repo_id.strip() + + try: + kernel = get_kernel(repo_id) + ALL_ATTENTION_FUNCTIONS.register( + f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name) + ) + config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + except FileNotFoundError as e: + logger.warning( + f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." + ) + config._attn_implementation = "eager" + except AttributeError: + raise ValueError( + "the kernel function name or class specified in the attn_implementation argument is not valid. \ + Please check the documentation for the correct format, \ + and check that the kernel exports the class and the function correctly." + ) + if ( not isinstance(config._attn_implementation, dict) and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() @@ -4299,7 +4332,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. if not getattr(config, "_attn_implementation_autoset", False): config = cls._autoset_attn_implementation( - config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + config, + use_flash_attention_2=use_flash_attention_2, + torch_dtype=torch_dtype, + device_map=device_map, ) with ContextManagers(model_init_context): diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c26fa30e0d..19997e2a9b 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -25,6 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput @@ -61,6 +62,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "AriaTextConfig" +@use_kernel_forward_from_hub("RMSNorm") class AriaTextRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -226,6 +228,7 @@ class AriaProjector(nn.Module): return out +@use_kernel_forward_from_hub("MLP") class AriaSharedExpertsMLP(nn.Module): """ Shared Expert MLP for shared experts. @@ -563,6 +566,7 @@ class AriaTextAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -623,7 +627,6 @@ class AriaTextDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 47c0ee5a93..7c084d37ed 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -34,6 +34,7 @@ from transformers.activations import ACT2FN from ...cache_utils import Cache # we need __iter__ and __len__ of pkv from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -295,6 +296,7 @@ class BambaAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -880,6 +882,7 @@ class BambaMixer(nn.Module): return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) +@use_kernel_forward_from_hub("MLP") class BambaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -896,6 +899,7 @@ class BambaMLP(nn.Module): return down_proj +@use_kernel_forward_from_hub("RMSNorm") class BambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index bd52795979..2083bb63ab 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -36,6 +36,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -117,6 +118,7 @@ class CohereRotaryEmbedding(nn.Module): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +@use_kernel_forward_from_hub("MLP") class CohereMLP(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 18a3a50ac1..e419379969 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -28,6 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -267,6 +268,7 @@ class Cohere2Attention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("MLP") class Cohere2MLP(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 7d932a4f15..e7ad5c5cf0 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -15,6 +15,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -44,6 +45,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV3Config" +@use_kernel_forward_from_hub("RMSNorm") class DeepseekV3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -481,7 +483,6 @@ class DeepseekV3DecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index e444a42323..e8031be755 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -31,6 +31,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, @@ -73,6 +74,7 @@ _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" _CONFIG_FOR_DOC = "DiffLlamaConfig" +@use_kernel_forward_from_hub("MLP") class DiffLlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -498,6 +500,7 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): return attn_output, None +@use_kernel_forward_from_hub("RMSNorm") class DiffLlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -549,7 +552,6 @@ class DiffLlamaDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 15d89e3677..341944a236 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -31,6 +31,7 @@ import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -62,6 +63,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Emu3Config" +@use_kernel_forward_from_hub("RMSNorm") class Emu3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -82,6 +84,7 @@ class Emu3RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_forward_from_hub("MLP") class Emu3MLP(nn.Module): def __init__(self, config): super().__init__() @@ -221,6 +224,7 @@ class Emu3Attention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9b349a4381..99e65dbae9 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -84,6 +85,7 @@ class GemmaRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.eps}" +@use_kernel_forward_from_hub("MLP") class GemmaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -257,6 +259,7 @@ class GemmaAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -306,7 +309,6 @@ class GemmaDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 144a94ef33..c7040de011 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -28,6 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -77,6 +78,7 @@ class Gemma2RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.eps}" +@use_kernel_forward_from_hub("MLP") class Gemma2MLP(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 0988e2692a..23f28281a1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -31,6 +31,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -106,6 +107,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) +@use_kernel_forward_from_hub("MLP") class Gemma3MLP(nn.Module): def __init__(self, config: Gemma3TextConfig): super().__init__() diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 07365a495f..f2acc45c66 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -28,6 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -219,6 +220,7 @@ class GlmAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -244,6 +246,7 @@ class GlmAttention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("RMSNorm") class GlmRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -322,7 +325,6 @@ class GlmDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 95c1c6543c..c30ced7197 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -28,6 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -277,6 +278,7 @@ class Glm4Attention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -305,6 +307,7 @@ class Glm4Attention(nn.Module): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... +@use_kernel_forward_from_hub("RMSNorm") class Glm4RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index e65f432291..4160a658f8 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -28,6 +28,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -180,6 +181,7 @@ class GraniteAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -205,6 +207,7 @@ class GraniteAttention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("RMSNorm") class GraniteRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -225,6 +228,7 @@ class GraniteRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_forward_from_hub("MLP") class GraniteMLP(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 37350ae462..599c8a5667 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -29,6 +29,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -117,6 +118,7 @@ class HeliumRotaryEmbedding(nn.Module): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +@use_kernel_forward_from_hub("MLP") class HeliumMLP(nn.Module): def __init__(self, config): super().__init__() @@ -260,6 +262,7 @@ class HeliumAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -309,7 +312,6 @@ class HeliumDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 938fd7d326..1a598f21a3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -59,6 +59,8 @@ if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask +from ...integrations import use_kernel_forward_from_hub + logger = logging.get_logger(__name__) @@ -66,6 +68,7 @@ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" _CONFIG_FOR_DOC = "LlamaConfig" +@use_kernel_forward_from_hub("RMSNorm") class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -157,6 +160,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +@use_kernel_forward_from_hub("MLP") class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -262,6 +266,7 @@ class LlamaAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( @@ -311,7 +316,6 @@ class LlamaDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5639c3bbb6..72a9c88eac 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -13,6 +13,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" _CONFIG_FOR_DOC = "MistralConfig" +@use_kernel_forward_from_hub("MLP") class MistralMLP(nn.Module): def __init__(self, config): super().__init__() @@ -200,6 +202,7 @@ class MistralAttention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("RMSNorm") class MistralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -242,7 +245,6 @@ class MistralDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 8ef1328466..6d60202db4 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -43,6 +44,7 @@ from .configuration_mistral3 import Mistral3Config _CONFIG_FOR_DOC = "Mistral3Config" +@use_kernel_forward_from_hub("RMSNorm") class Mistral3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 51604dec3f..cb4e0a9c36 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -34,6 +34,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -152,6 +153,7 @@ class MixtralSparseMoeBlock(nn.Module): return final_hidden_states, router_logits +@use_kernel_forward_from_hub("RMSNorm") class MixtralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 317c7ac6df..6e470f7bf7 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -381,7 +381,6 @@ class MoonshineEncoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 4ad98556ee..8b015057ef 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -57,6 +58,7 @@ class OlmoLayerNorm(nn.Module): ) +@use_kernel_forward_from_hub("MLP") class OlmoMLP(nn.Module): def __init__(self, config): super().__init__() @@ -253,7 +255,6 @@ class OlmoDecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 3404c5e817..c99ba1f0de 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -13,6 +13,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -42,6 +43,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Olmo2Config" +@use_kernel_forward_from_hub("RMSNorm") class Olmo2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -216,6 +218,7 @@ class Olmo2Attention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("MLP") class Olmo2MLP(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 1c8688b28f..0b0d7b626b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -29,6 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -229,6 +230,7 @@ class Phi3Attention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("RMSNorm") class Phi3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index fb399f6d83..68a9ae5536 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -33,6 +33,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -1281,6 +1282,7 @@ class Phi4MultimodalAudioEmbedding(nn.Module): return audio_embeds +@use_kernel_forward_from_hub("RMSNorm") class Phi4MultimodalRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 16a7316e2d..661e3181d7 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -13,6 +13,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" _CONFIG_FOR_DOC = "Qwen2Config" +@use_kernel_forward_from_hub("MLP") class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() @@ -208,6 +210,7 @@ class Qwen2Attention(nn.Module): return attn_output, attn_weights +@use_kernel_forward_from_hub("RMSNorm") class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -255,7 +258,6 @@ class Qwen2DecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 5fec83d478..dbf31b6d52 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,6 +28,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -59,6 +60,7 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" _CONFIG_FOR_DOC = "Qwen3Config" +@use_kernel_forward_from_hub("RMSNorm") class Qwen3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -79,6 +81,7 @@ class Qwen3RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_forward_from_hub("MLP") class Qwen3MLP(nn.Module): def __init__(self, config): super().__init__() @@ -282,7 +285,6 @@ class Qwen3DecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 61e4a88049..e56b121b3a 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -29,6 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -289,6 +290,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): return final_hidden_states, router_logits +@use_kernel_forward_from_hub("RMSNorm") class Qwen3MoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 089cb00dce..4ca5ecf636 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -246,7 +246,6 @@ class Starcoder2DecoderLayer(nn.Module): **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # Self Attention diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 3a1c2e5998..221bb39c84 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -156,6 +156,7 @@ from .import_utils import ( is_jumanpp_available, is_kenlm_available, is_keras_nlp_available, + is_kernels_available, is_levenshtein_available, is_librosa_available, is_liger_kernel_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 184b618e7e..75c88dd019 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -216,6 +216,7 @@ _liger_kernel_available = _is_package_available("liger_kernel") _triton_available = _is_package_available("triton") _spqr_available = _is_package_available("spqr_quant") _rich_available = _is_package_available("rich") +_kernels_available = _is_package_available("kernels") _torch_version = "N/A" _torch_available = False @@ -329,6 +330,10 @@ def is_kenlm_available(): return _kenlm_available +def is_kernels_available(): + return _kernels_available + + def is_cv2_available(): return _cv2_available