Llama Kernel integration (#37092)
* initial commit * style * update * change approach attention * clean up * fix import * update * update * fix style * change method * attention * add mlp back * change name * update name * fix copies * fix config * fix
This commit is contained in:
@@ -31,7 +31,20 @@ try:
|
|||||||
repo_id="kernels-community/deformable-detr",
|
repo_id="kernels-community/deformable-detr",
|
||||||
layer_name="MultiScaleDeformableAttention",
|
layer_name="MultiScaleDeformableAttention",
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
|
"RMSNorm": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/triton-layer-norm",
|
||||||
|
layer_name="LlamaRMSNorm",
|
||||||
|
revision="pure-layer-test",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"MLP": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="medmekk/triton-llama-mlp",
|
||||||
|
layer_name="TritonLlamaMLP",
|
||||||
|
)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
register_kernel_mapping(_KERNEL_MAPPING)
|
register_kernel_mapping(_KERNEL_MAPPING)
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ from .utils import (
|
|||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
|
is_kernels_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_optimum_available,
|
is_optimum_available,
|
||||||
is_peft_available,
|
is_peft_available,
|
||||||
@@ -157,6 +158,9 @@ if is_safetensors_available():
|
|||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
if is_kernels_available():
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -2024,6 +2028,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
|
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hub-kernel attention: `attn_implementation` may be given as "<org>/<repo>:<name>".
# In that case the kernel repository is fetched with `get_kernel`, the attribute
# `<name>` of the loaded kernel is registered in `ALL_ATTENTION_FUNCTIONS`, and
# `config._attn_implementation` is rewritten to the registered key.
#
# Raises:
#     ValueError: if the `kernels` package is not installed, or if the loaded kernel
#         has no attribute called `<name>`.
# Falls back to "eager" (with a warning) when `get_kernel` raises `FileNotFoundError`.
if isinstance(config._attn_implementation, str) and re.match(
    r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
):
    if not is_kernels_available():
        raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")

    # The pattern above guarantees exactly one ":", so the split always yields two parts.
    repo_id, kernel_name = config._attn_implementation.split(":")
    kernel_name = kernel_name.strip()
    repo_id = repo_id.strip()

    # Computed once so the registered key and the value stored on the config cannot drift.
    # NOTE(review): the key is derived from the repo id only, so two different names from
    # the same repository would register under the same key — confirm this is intended.
    attn_key = f"kernel_{repo_id.replace('/', '_')}"

    try:
        kernel = get_kernel(repo_id)
        ALL_ATTENTION_FUNCTIONS.register(attn_key, getattr(kernel, kernel_name))
        config._attn_implementation = attn_key
    except FileNotFoundError as e:
        logger.warning(
            f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. "
            "Using eager attention implementation instead."
        )
        config._attn_implementation = "eager"
    except AttributeError as e:
        # Adjacent literals instead of backslash continuations inside one literal: the latter
        # embeds the continuation lines' leading indentation in the user-facing message.
        raise ValueError(
            "the kernel function name or class specified in the attn_implementation argument is not valid. "
            "Please check the documentation for the correct format, "
            "and check that the kernel exports the class and the function correctly."
        ) from e
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not isinstance(config._attn_implementation, dict)
|
not isinstance(config._attn_implementation, dict)
|
||||||
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||||
@@ -4299,7 +4332,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||||
if not getattr(config, "_attn_implementation_autoset", False):
|
if not getattr(config, "_attn_implementation_autoset", False):
|
||||||
config = cls._autoset_attn_implementation(
|
config = cls._autoset_attn_implementation(
|
||||||
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
|
config,
|
||||||
|
use_flash_attention_2=use_flash_attention_2,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
device_map=device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
with ContextManagers(model_init_context):
|
with ContextManagers(model_init_context):
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||||
@@ -61,6 +62,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "AriaTextConfig"
|
_CONFIG_FOR_DOC = "AriaTextConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class AriaTextRMSNorm(nn.Module):
|
class AriaTextRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -226,6 +228,7 @@ class AriaProjector(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class AriaSharedExpertsMLP(nn.Module):
|
class AriaSharedExpertsMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
Shared Expert MLP for shared experts.
|
Shared Expert MLP for shared experts.
|
||||||
@@ -563,6 +566,7 @@ class AriaTextAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -623,7 +627,6 @@ class AriaTextDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from transformers.activations import ACT2FN
|
|||||||
|
|
||||||
from ...cache_utils import Cache # we need __iter__ and __len__ of pkv
|
from ...cache_utils import Cache # we need __iter__ and __len__ of pkv
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -295,6 +296,7 @@ class BambaAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -880,6 +882,7 @@ class BambaMixer(nn.Module):
|
|||||||
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
|
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class BambaMLP(nn.Module):
|
class BambaMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -896,6 +899,7 @@ class BambaMLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class BambaRMSNorm(nn.Module):
|
class BambaRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -117,6 +118,7 @@ class CohereRotaryEmbedding(nn.Module):
|
|||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class CohereMLP(nn.Module):
|
class CohereMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
@@ -267,6 +268,7 @@ class Cohere2Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Cohere2MLP(nn.Module):
|
class Cohere2MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -44,6 +45,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "DeepseekV3Config"
|
_CONFIG_FOR_DOC = "DeepseekV3Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class DeepseekV3RMSNorm(nn.Module):
|
class DeepseekV3RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -481,7 +483,6 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import (
|
from ...modeling_flash_attention_utils import (
|
||||||
FlashAttentionKwargs,
|
FlashAttentionKwargs,
|
||||||
@@ -73,6 +74,7 @@ _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
|
|||||||
_CONFIG_FOR_DOC = "DiffLlamaConfig"
|
_CONFIG_FOR_DOC = "DiffLlamaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class DiffLlamaMLP(nn.Module):
|
class DiffLlamaMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -498,6 +500,7 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
|
|||||||
return attn_output, None
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class DiffLlamaRMSNorm(nn.Module):
|
class DiffLlamaRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -549,7 +552,6 @@ class DiffLlamaDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import torch.nn.functional as F
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -62,6 +63,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "Emu3Config"
|
_CONFIG_FOR_DOC = "Emu3Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Emu3RMSNorm(nn.Module):
|
class Emu3RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -82,6 +84,7 @@ class Emu3RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Emu3MLP(nn.Module):
|
class Emu3MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -221,6 +224,7 @@ class Emu3Attention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -84,6 +85,7 @@ class GemmaRMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class GemmaMLP(nn.Module):
|
class GemmaMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -257,6 +259,7 @@ class GemmaAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -306,7 +309,6 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
@@ -77,6 +78,7 @@ class Gemma2RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Gemma2MLP(nn.Module):
|
class Gemma2MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
@@ -106,6 +107,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
|
|||||||
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Gemma3MLP(nn.Module):
|
class Gemma3MLP(nn.Module):
|
||||||
def __init__(self, config: Gemma3TextConfig):
|
def __init__(self, config: Gemma3TextConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -219,6 +220,7 @@ class GlmAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -244,6 +246,7 @@ class GlmAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class GlmRMSNorm(nn.Module):
|
class GlmRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -322,7 +325,6 @@ class GlmDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -277,6 +278,7 @@ class Glm4Attention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -305,6 +307,7 @@ class Glm4Attention(nn.Module):
|
|||||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Glm4RMSNorm(nn.Module):
|
class Glm4RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -180,6 +181,7 @@ class GraniteAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -205,6 +207,7 @@ class GraniteAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class GraniteRMSNorm(nn.Module):
|
class GraniteRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -225,6 +228,7 @@ class GraniteRMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class GraniteMLP(nn.Module):
|
class GraniteMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -117,6 +118,7 @@ class HeliumRotaryEmbedding(nn.Module):
|
|||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class HeliumMLP(nn.Module):
|
class HeliumMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -260,6 +262,7 @@ class HeliumAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -309,7 +312,6 @@ class HeliumDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ if is_torch_flex_attn_available():
|
|||||||
|
|
||||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||||
|
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@@ -66,6 +68,7 @@ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
|
|||||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class LlamaRMSNorm(nn.Module):
|
class LlamaRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -157,6 +160,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -262,6 +266,7 @@ class LlamaAttention(nn.Module):
|
|||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
attention_interface: Callable = eager_attention_forward
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
|
||||||
if self.config._attn_implementation != "eager":
|
if self.config._attn_implementation != "eager":
|
||||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -311,7 +316,6 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
|
|||||||
_CONFIG_FOR_DOC = "MistralConfig"
|
_CONFIG_FOR_DOC = "MistralConfig"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class MistralMLP(nn.Module):
|
class MistralMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -200,6 +202,7 @@ class MistralAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class MistralRMSNorm(nn.Module):
|
class MistralRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -242,7 +245,6 @@ class MistralDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from torch import nn
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_outputs import ModelOutput
|
from ...modeling_outputs import ModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -43,6 +44,7 @@ from .configuration_mistral3 import Mistral3Config
|
|||||||
_CONFIG_FOR_DOC = "Mistral3Config"
|
_CONFIG_FOR_DOC = "Mistral3Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Mistral3RMSNorm(nn.Module):
|
class Mistral3RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -152,6 +153,7 @@ class MixtralSparseMoeBlock(nn.Module):
|
|||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class MixtralRMSNorm(nn.Module):
|
class MixtralRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -381,7 +381,6 @@ class MoonshineEncoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -57,6 +58,7 @@ class OlmoLayerNorm(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class OlmoMLP(nn.Module):
|
class OlmoMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -253,7 +255,6 @@ class OlmoDecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import torch.nn as nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
@@ -42,6 +43,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "Olmo2Config"
|
_CONFIG_FOR_DOC = "Olmo2Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Olmo2RMSNorm(nn.Module):
|
class Olmo2RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -216,6 +218,7 @@ class Olmo2Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Olmo2MLP(nn.Module):
|
class Olmo2MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -229,6 +230,7 @@ class Phi3Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Phi3RMSNorm(nn.Module):
|
class Phi3RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1281,6 +1282,7 @@ class Phi4MultimodalAudioEmbedding(nn.Module):
|
|||||||
return audio_embeds
|
return audio_embeds
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Phi4MultimodalRMSNorm(nn.Module):
|
class Phi4MultimodalRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -44,6 +45,7 @@ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
|
|||||||
_CONFIG_FOR_DOC = "Qwen2Config"
|
_CONFIG_FOR_DOC = "Qwen2Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Qwen2MLP(nn.Module):
|
class Qwen2MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -208,6 +210,7 @@ class Qwen2Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Qwen2RMSNorm(nn.Module):
|
class Qwen2RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -255,7 +258,6 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -59,6 +60,7 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
|
|||||||
_CONFIG_FOR_DOC = "Qwen3Config"
|
_CONFIG_FOR_DOC = "Qwen3Config"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Qwen3RMSNorm(nn.Module):
|
class Qwen3RMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -79,6 +81,7 @@ class Qwen3RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("MLP")
|
||||||
class Qwen3MLP(nn.Module):
|
class Qwen3MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -282,7 +285,6 @@ class Qwen3DecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from torch import nn
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -289,6 +290,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("RMSNorm")
|
||||||
class Qwen3MoeRMSNorm(nn.Module):
|
class Qwen3MoeRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -246,7 +246,6 @@ class Starcoder2DecoderLayer(nn.Module):
|
|||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ from .import_utils import (
|
|||||||
is_jumanpp_available,
|
is_jumanpp_available,
|
||||||
is_kenlm_available,
|
is_kenlm_available,
|
||||||
is_keras_nlp_available,
|
is_keras_nlp_available,
|
||||||
|
is_kernels_available,
|
||||||
is_levenshtein_available,
|
is_levenshtein_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_liger_kernel_available,
|
is_liger_kernel_available,
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ _liger_kernel_available = _is_package_available("liger_kernel")
|
|||||||
_triton_available = _is_package_available("triton")
|
_triton_available = _is_package_available("triton")
|
||||||
_spqr_available = _is_package_available("spqr_quant")
|
_spqr_available = _is_package_available("spqr_quant")
|
||||||
_rich_available = _is_package_available("rich")
|
_rich_available = _is_package_available("rich")
|
||||||
|
_kernels_available = _is_package_available("kernels")
|
||||||
|
|
||||||
_torch_version = "N/A"
|
_torch_version = "N/A"
|
||||||
_torch_available = False
|
_torch_available = False
|
||||||
@@ -329,6 +330,10 @@ def is_kenlm_available():
|
|||||||
return _kenlm_available
|
return _kenlm_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_kernels_available():
|
||||||
|
return _kernels_available
|
||||||
|
|
||||||
|
|
||||||
def is_cv2_available():
|
def is_cv2_available():
|
||||||
return _cv2_available
|
return _cv2_available
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user