Llama Kernel integration (#37092)
* initial commit
* style
* update
* change approach attention
* clean up
* fix import
* update
* update
* fix style
* change method
* attention
* add mlp back
* change name
* update name
* fix copies
* fix config
* fix
This commit is contained in:
@@ -25,6 +25,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||
@@ -61,6 +62,7 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "AriaTextConfig"
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class AriaTextRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
@@ -226,6 +228,7 @@ class AriaProjector(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("MLP")
|
||||
class AriaSharedExpertsMLP(nn.Module):
|
||||
"""
|
||||
Shared Expert MLP for shared experts.
|
||||
@@ -563,6 +566,7 @@ class AriaTextAttention(nn.Module):
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
@@ -623,7 +627,6 @@ class AriaTextDecoderLayer(nn.Module):
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
|
||||
Reference in New Issue
Block a user