Llama Kernel integration (#37092)

* initial commit

* style

* update

* change approach attention

* clean up

* fix import

* update

* update

* fix style

* change method

* attention

* add mlp back

* change name

* update name

* fix copies

* fix config

* fix
This commit is contained in:
Mohamed Mekkouri
2025-04-10 17:13:25 +02:00
committed by GitHub
parent 9c0c323e12
commit 0ea1151222
31 changed files with 127 additions and 15 deletions

View File

@@ -31,7 +31,20 @@ try:
repo_id="kernels-community/deformable-detr", repo_id="kernels-community/deformable-detr",
layer_name="MultiScaleDeformableAttention", layer_name="MultiScaleDeformableAttention",
) )
} },
"RMSNorm": {
"cuda": LayerRepository(
repo_id="kernels-community/triton-layer-norm",
layer_name="LlamaRMSNorm",
revision="pure-layer-test",
)
},
"MLP": {
"cuda": LayerRepository(
repo_id="medmekk/triton-llama-mlp",
layer_name="TritonLlamaMLP",
)
},
} }
register_kernel_mapping(_KERNEL_MAPPING) register_kernel_mapping(_KERNEL_MAPPING)

View File

@@ -102,6 +102,7 @@ from .utils import (
is_accelerate_available, is_accelerate_available,
is_bitsandbytes_available, is_bitsandbytes_available,
is_flash_attn_2_available, is_flash_attn_2_available,
is_kernels_available,
is_offline_mode, is_offline_mode,
is_optimum_available, is_optimum_available,
is_peft_available, is_peft_available,
@@ -157,6 +158,9 @@ if is_safetensors_available():
if is_deepspeed_available(): if is_deepspeed_available():
import deepspeed import deepspeed
if is_kernels_available():
from kernels import get_kernel
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -2024,6 +2028,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
) )
if isinstance(config._attn_implementation, str) and re.match(
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = config._attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(
f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)
)
config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
config._attn_implementation = "eager"
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
if ( if (
not isinstance(config._attn_implementation, dict) not isinstance(config._attn_implementation, dict)
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
@@ -4299,7 +4332,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False): if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation( config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
) )
with ContextManagers(model_init_context): with ContextManagers(model_init_context):

View File

@@ -25,6 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
@@ -61,6 +62,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig" _CONFIG_FOR_DOC = "AriaTextConfig"
@use_kernel_forward_from_hub("RMSNorm")
class AriaTextRMSNorm(nn.Module): class AriaTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -226,6 +228,7 @@ class AriaProjector(nn.Module):
return out return out
@use_kernel_forward_from_hub("MLP")
class AriaSharedExpertsMLP(nn.Module): class AriaSharedExpertsMLP(nn.Module):
""" """
Shared Expert MLP for shared experts. Shared Expert MLP for shared experts.
@@ -563,6 +566,7 @@ class AriaTextAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -623,7 +627,6 @@ class AriaTextDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -34,6 +34,7 @@ from transformers.activations import ACT2FN
from ...cache_utils import Cache # we need __iter__ and __len__ of pkv from ...cache_utils import Cache # we need __iter__ and __len__ of pkv
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -295,6 +296,7 @@ class BambaAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -880,6 +882,7 @@ class BambaMixer(nn.Module):
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
@use_kernel_forward_from_hub("MLP")
class BambaMLP(nn.Module): class BambaMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -896,6 +899,7 @@ class BambaMLP(nn.Module):
return down_proj return down_proj
@use_kernel_forward_from_hub("RMSNorm")
class BambaRMSNorm(nn.Module): class BambaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -36,6 +36,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -117,6 +118,7 @@ class CohereRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@use_kernel_forward_from_hub("MLP")
class CohereMLP(nn.Module): class CohereMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@@ -28,6 +28,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -267,6 +268,7 @@ class Cohere2Attention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("MLP")
class Cohere2MLP(nn.Module): class Cohere2MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@@ -15,6 +15,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -44,6 +45,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeepseekV3Config" _CONFIG_FOR_DOC = "DeepseekV3Config"
@use_kernel_forward_from_hub("RMSNorm")
class DeepseekV3RMSNorm(nn.Module): class DeepseekV3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -481,7 +483,6 @@ class DeepseekV3DecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -31,6 +31,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import ( from ...modeling_flash_attention_utils import (
FlashAttentionKwargs, FlashAttentionKwargs,
@@ -73,6 +74,7 @@ _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig" _CONFIG_FOR_DOC = "DiffLlamaConfig"
@use_kernel_forward_from_hub("MLP")
class DiffLlamaMLP(nn.Module): class DiffLlamaMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -498,6 +500,7 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
return attn_output, None return attn_output, None
@use_kernel_forward_from_hub("RMSNorm")
class DiffLlamaRMSNorm(nn.Module): class DiffLlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -549,7 +552,6 @@ class DiffLlamaDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -31,6 +31,7 @@ import torch.nn.functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -62,6 +63,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Emu3Config" _CONFIG_FOR_DOC = "Emu3Config"
@use_kernel_forward_from_hub("RMSNorm")
class Emu3RMSNorm(nn.Module): class Emu3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -82,6 +84,7 @@ class Emu3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@use_kernel_forward_from_hub("MLP")
class Emu3MLP(nn.Module): class Emu3MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -221,6 +224,7 @@ class Emu3Attention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(

View File

@@ -27,6 +27,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -84,6 +85,7 @@ class GemmaRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}" return f"{tuple(self.weight.shape)}, eps={self.eps}"
@use_kernel_forward_from_hub("MLP")
class GemmaMLP(nn.Module): class GemmaMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -257,6 +259,7 @@ class GemmaAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -306,7 +309,6 @@ class GemmaDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -28,6 +28,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
@@ -77,6 +78,7 @@ class Gemma2RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}" return f"{tuple(self.weight.shape)}, eps={self.eps}"
@use_kernel_forward_from_hub("MLP")
class Gemma2MLP(nn.Module): class Gemma2MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@@ -31,6 +31,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -106,6 +107,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
@use_kernel_forward_from_hub("MLP")
class Gemma3MLP(nn.Module): class Gemma3MLP(nn.Module):
def __init__(self, config: Gemma3TextConfig): def __init__(self, config: Gemma3TextConfig):
super().__init__() super().__init__()

View File

@@ -28,6 +28,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -219,6 +220,7 @@ class GlmAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -244,6 +246,7 @@ class GlmAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module): class GlmRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -322,7 +325,6 @@ class GlmDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -28,6 +28,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -277,6 +278,7 @@ class Glm4Attention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -305,6 +307,7 @@ class Glm4Attention(nn.Module):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@use_kernel_forward_from_hub("RMSNorm")
class Glm4RMSNorm(nn.Module): class Glm4RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -28,6 +28,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -180,6 +181,7 @@ class GraniteAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -205,6 +207,7 @@ class GraniteAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class GraniteRMSNorm(nn.Module): class GraniteRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -225,6 +228,7 @@ class GraniteRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@use_kernel_forward_from_hub("MLP")
class GraniteMLP(nn.Module): class GraniteMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@@ -29,6 +29,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -117,6 +118,7 @@ class HeliumRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@use_kernel_forward_from_hub("MLP")
class HeliumMLP(nn.Module): class HeliumMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -260,6 +262,7 @@ class HeliumAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -309,7 +312,6 @@ class HeliumDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -59,6 +59,8 @@ if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask from ...integrations.flex_attention import make_flex_block_causal_mask
from ...integrations import use_kernel_forward_from_hub
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -66,6 +68,7 @@ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig" _CONFIG_FOR_DOC = "LlamaConfig"
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -157,6 +160,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed return q_embed, k_embed
@use_kernel_forward_from_hub("MLP")
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -262,6 +266,7 @@ class LlamaAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -311,7 +316,6 @@ class LlamaDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -13,6 +13,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CONFIG_FOR_DOC = "MistralConfig" _CONFIG_FOR_DOC = "MistralConfig"
@use_kernel_forward_from_hub("MLP")
class MistralMLP(nn.Module): class MistralMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -200,6 +202,7 @@ class MistralAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class MistralRMSNorm(nn.Module): class MistralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -242,7 +245,6 @@ class MistralDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -27,6 +27,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import ModelOutput from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
@@ -43,6 +44,7 @@ from .configuration_mistral3 import Mistral3Config
_CONFIG_FOR_DOC = "Mistral3Config" _CONFIG_FOR_DOC = "Mistral3Config"
@use_kernel_forward_from_hub("RMSNorm")
class Mistral3RMSNorm(nn.Module): class Mistral3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -34,6 +34,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -152,6 +153,7 @@ class MixtralSparseMoeBlock(nn.Module):
return final_hidden_states, router_logits return final_hidden_states, router_logits
@use_kernel_forward_from_hub("RMSNorm")
class MixtralRMSNorm(nn.Module): class MixtralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -381,7 +381,6 @@ class MoonshineEncoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -57,6 +58,7 @@ class OlmoLayerNorm(nn.Module):
) )
@use_kernel_forward_from_hub("MLP")
class OlmoMLP(nn.Module): class OlmoMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -253,7 +255,6 @@ class OlmoDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -13,6 +13,7 @@ import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -42,6 +43,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Olmo2Config" _CONFIG_FOR_DOC = "Olmo2Config"
@use_kernel_forward_from_hub("RMSNorm")
class Olmo2RMSNorm(nn.Module): class Olmo2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -216,6 +218,7 @@ class Olmo2Attention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("MLP")
class Olmo2MLP(nn.Module): class Olmo2MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@@ -29,6 +29,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -229,6 +230,7 @@ class Phi3Attention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class Phi3RMSNorm(nn.Module): class Phi3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -33,6 +33,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -1281,6 +1282,7 @@ class Phi4MultimodalAudioEmbedding(nn.Module):
return audio_embeds return audio_embeds
@use_kernel_forward_from_hub("RMSNorm")
class Phi4MultimodalRMSNorm(nn.Module): class Phi4MultimodalRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -13,6 +13,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
_CONFIG_FOR_DOC = "Qwen2Config" _CONFIG_FOR_DOC = "Qwen2Config"
@use_kernel_forward_from_hub("MLP")
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -208,6 +210,7 @@ class Qwen2Attention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class Qwen2RMSNorm(nn.Module): class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -255,7 +258,6 @@ class Qwen2DecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -28,6 +28,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -59,6 +60,7 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
_CONFIG_FOR_DOC = "Qwen3Config" _CONFIG_FOR_DOC = "Qwen3Config"
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module): class Qwen3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -79,6 +81,7 @@ class Qwen3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@use_kernel_forward_from_hub("MLP")
class Qwen3MLP(nn.Module): class Qwen3MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -282,7 +285,6 @@ class Qwen3DecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -29,6 +29,7 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
@@ -289,6 +290,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return final_hidden_states, router_logits return final_hidden_states, router_logits
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3MoeRMSNorm(nn.Module): class Qwen3MoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """

View File

@@ -246,7 +246,6 @@ class Starcoder2DecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention

View File

@@ -156,6 +156,7 @@ from .import_utils import (
is_jumanpp_available, is_jumanpp_available,
is_kenlm_available, is_kenlm_available,
is_keras_nlp_available, is_keras_nlp_available,
is_kernels_available,
is_levenshtein_available, is_levenshtein_available,
is_librosa_available, is_librosa_available,
is_liger_kernel_available, is_liger_kernel_available,

View File

@@ -216,6 +216,7 @@ _liger_kernel_available = _is_package_available("liger_kernel")
_triton_available = _is_package_available("triton") _triton_available = _is_package_available("triton")
_spqr_available = _is_package_available("spqr_quant") _spqr_available = _is_package_available("spqr_quant")
_rich_available = _is_package_available("rich") _rich_available = _is_package_available("rich")
_kernels_available = _is_package_available("kernels")
_torch_version = "N/A" _torch_version = "N/A"
_torch_available = False _torch_available = False
@@ -329,6 +330,10 @@ def is_kenlm_available():
return _kenlm_available return _kenlm_available
def is_kernels_available():
return _kernels_available
def is_cv2_available(): def is_cv2_available():
return _cv2_available return _cv2_available