Llama Kernel integration (#37092)

* initial commit

* style

* update

* change approach attention

* clean up

* fix import

* update

* update

* fix style

* change method

* attention

* add mlp back

* change name

* update name

* fix copies

* fix config

* fix
This commit is contained in:
Mohamed Mekkouri
2025-04-10 17:13:25 +02:00
committed by GitHub
parent 9c0c323e12
commit 0ea1151222
31 changed files with 127 additions and 15 deletions

View File

@@ -25,6 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
@@ -61,6 +62,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"
@use_kernel_forward_from_hub("RMSNorm")
class AriaTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -226,6 +228,7 @@ class AriaProjector(nn.Module):
return out
@use_kernel_forward_from_hub("MLP")
class AriaSharedExpertsMLP(nn.Module):
"""
Shared Expert MLP for shared experts.
@@ -563,6 +566,7 @@ class AriaTextAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
@@ -623,7 +627,6 @@ class AriaTextDecoderLayer(nn.Module):
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention