[Ernie 4.5] Post merge adaptations (#39664)

* ernie 4.5 fixes

* Apply style fixes

* fix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Anton Vlasjuk
2025-07-25 17:36:18 +02:00
committed by GitHub
parent 5d0ba3e479
commit a91653561e
10 changed files with 126 additions and 101 deletions

View File

@@ -31,7 +31,7 @@ The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.b
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core. model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core.
Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md). Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe.md).
<div class="flex justify-center"> <div class="flex justify-center">
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/> <img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>

View File

@@ -23,11 +23,11 @@ rendered properly in your Markdown viewer.
</div> </div>
</div> </div>
# Ernie 4.5 MoE # Ernie 4.5 Moe
## Overview ## Overview
The Ernie 4.5 MoE model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu. The Ernie 4.5 Moe model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters. model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters.
It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared
@@ -167,17 +167,17 @@ This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE). The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).
## Ernie4_5_MoEConfig ## Ernie4_5_MoeConfig
[[autodoc]] Ernie4_5_MoEConfig [[autodoc]] Ernie4_5_MoeConfig
## Ernie4_5_MoEModel ## Ernie4_5_MoeModel
[[autodoc]] Ernie4_5_MoEModel [[autodoc]] Ernie4_5_MoeModel
- forward - forward
## Ernie4_5_MoEForCausalLM ## Ernie4_5_MoeForCausalLM
[[autodoc]] Ernie4_5_MoEForCausalLM [[autodoc]] Ernie4_5_MoeForCausalLM
- forward - forward
- generate - generate

View File

@@ -130,7 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("eomt", "EomtConfig"), ("eomt", "EomtConfig"),
("ernie", "ErnieConfig"), ("ernie", "ErnieConfig"),
("ernie4_5", "Ernie4_5Config"), ("ernie4_5", "Ernie4_5Config"),
("ernie4_5_moe", "Ernie4_5_MoEConfig"), ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
("ernie_m", "ErnieMConfig"), ("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"), ("esm", "EsmConfig"),
("falcon", "FalconConfig"), ("falcon", "FalconConfig"),

View File

@@ -121,7 +121,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("encodec", "EncodecModel"), ("encodec", "EncodecModel"),
("ernie", "ErnieModel"), ("ernie", "ErnieModel"),
("ernie4_5", "Ernie4_5Model"), ("ernie4_5", "Ernie4_5Model"),
("ernie4_5_moe", "Ernie4_5_MoEModel"), ("ernie4_5_moe", "Ernie4_5_MoeModel"),
("ernie_m", "ErnieMModel"), ("ernie_m", "ErnieMModel"),
("esm", "EsmModel"), ("esm", "EsmModel"),
("falcon", "FalconModel"), ("falcon", "FalconModel"),
@@ -597,7 +597,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3ForCausalLM"), ("emu3", "Emu3ForCausalLM"),
("ernie", "ErnieForCausalLM"), ("ernie", "ErnieForCausalLM"),
("ernie4_5", "Ernie4_5ForCausalLM"), ("ernie4_5", "Ernie4_5ForCausalLM"),
("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"), ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
("falcon", "FalconForCausalLM"), ("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"), ("falcon_h1", "FalconH1ForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"), ("falcon_mamba", "FalconMambaForCausalLM"),

View File

@@ -21,9 +21,9 @@ from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Ernie4_5_MoEConfig(PretrainedConfig): class Ernie4_5_MoeConfig(PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`Ernie4_5_MoEModel`]. It is used to instantiate a This is the configuration class to store the configuration of a [`Ernie4_5_MoeModel`]. It is used to instantiate a
Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT). with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT).
@@ -34,7 +34,7 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
Args: Args:
vocab_size (`int`, *optional*, defaults to 103424): vocab_size (`int`, *optional*, defaults to 103424):
Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Ernie4_5_MoEModel`] `inputs_ids` passed when calling [`Ernie4_5_MoeModel`]
pad_token_id (`int`, *optional*, defaults to 0): pad_token_id (`int`, *optional*, defaults to 0):
Padding token id. Padding token id.
bos_token_id (`int`, *optional*, defaults to 1): bos_token_id (`int`, *optional*, defaults to 1):
@@ -133,13 +133,13 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
The aux loss factor for the total loss. The aux loss factor for the total loss.
```python ```python
>>> from transformers import Ernie4_5_MoEModel, Ernie4_5_MoEConfig >>> from transformers import Ernie4_5_MoeModel, Ernie4_5_MoEConfig
>>> # Initializing a Ernie4_5_MoE style configuration >>> # Initializing a Ernie4_5_MoE style configuration
>>> configuration = Ernie4_5_MoEConfig() >>> configuration = Ernie4_5_MoEConfig()
>>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration >>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration
>>> model = Ernie4_5_MoEModel(configuration) >>> model = Ernie4_5_MoeModel(configuration)
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
@@ -251,4 +251,4 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
) )
__all__ = ["Ernie4_5_MoEConfig"] __all__ = ["Ernie4_5_MoeConfig"]

View File

@@ -37,14 +37,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import OutputRecorder, check_model_inputs from ...utils.generic import OutputRecorder, check_model_inputs
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
@use_kernel_forward_from_hub("RMSNorm") @use_kernel_forward_from_hub("RMSNorm")
class Ernie4_5_MoERMSNorm(nn.Module): class Ernie4_5_MoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
Ernie4_5_MoERMSNorm is equivalent to T5LayerNorm Ernie4_5_MoeRMSNorm is equivalent to T5LayerNorm
""" """
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -61,7 +61,7 @@ class Ernie4_5_MoERMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Ernie4_5_MoEMLP(nn.Module): class Ernie4_5_MoeMLP(nn.Module):
def __init__(self, config, intermediate_size=None): def __init__(self, config, intermediate_size=None):
super().__init__() super().__init__()
self.config = config self.config = config
@@ -78,8 +78,8 @@ class Ernie4_5_MoEMLP(nn.Module):
return down_proj return down_proj
class Ernie4_5_MoERotaryEmbedding(nn.Module): class Ernie4_5_MoeRotaryEmbedding(nn.Module):
def __init__(self, config: Ernie4_5_MoEConfig, device=None): def __init__(self, config: Ernie4_5_MoeConfig, device=None):
super().__init__() super().__init__()
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
@@ -194,10 +194,10 @@ def eager_attention_forward(
return attn_output, attn_weights return attn_output, attn_weights
class Ernie4_5_MoEAttention(nn.Module): class Ernie4_5_MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int): def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_idx = layer_idx self.layer_idx = layer_idx
@@ -257,7 +257,7 @@ class Ernie4_5_MoEAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
class Ernie4_5_MoEStatics(nn.Module): class Ernie4_5_MoeStatics(nn.Module):
""" """
Stores MoE (Mixture of Experts) statistics Stores MoE (Mixture of Experts) statistics
- Bias for the gating - Bias for the gating
@@ -284,7 +284,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze() return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoESparseMoeBlock(nn.Module): class Ernie4_5_MoeSparseMoeBlock(nn.Module):
""" """
This implementation is This implementation is
strictly equivalent to standard MoE with full capacity (no strictly equivalent to standard MoE with full capacity (no
@@ -305,19 +305,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics) # correction bias (yes it seems to be a typo with statics <> statistics)
self.moe_statics = Ernie4_5_MoEStatics(config) self.moe_statics = Ernie4_5_MoeStatics(config)
# gating # gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList( self.experts = nn.ModuleList(
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
) )
self.norm_min = config.moe_norm_min self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards # (optional) shared experts for all forwards
self.shared_experts = None self.shared_experts = None
if config.moe_num_shared_experts > 0: if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward( def forward(
self, self,
@@ -379,24 +379,24 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits return final_hidden_states, router_logits
class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer): class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx): def __init__(self, config, layer_idx):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx) self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if ( if (
((layer_idx + 1) % config.moe_layer_interval == 0) ((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index and layer_idx <= config.moe_layer_end_index
): ):
self.mlp = Ernie4_5_MoESparseMoeBlock(config) self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else: else:
self.mlp = Ernie4_5_MoEMLP(config) self.mlp = Ernie4_5_MoeMLP(config)
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
def forward( def forward(
self, self,
@@ -461,11 +461,11 @@ class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
@auto_docstring @auto_docstring
class Ernie4_5_MoEPreTrainedModel(PreTrainedModel): class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
config: Ernie4_5_MoEConfig config: Ernie4_5_MoeConfig
base_model_prefix = "model" base_model_prefix = "model"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["Ernie4_5_MoEDecoderLayer"] _no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True _supports_flash_attn = True
_supports_sdpa = True _supports_sdpa = True
@@ -473,31 +473,33 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True _supports_attention_backend = True
_can_record_outputs = { _can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1), "router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
"hidden_states": Ernie4_5_MoEDecoderLayer, "hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoEAttention, "attentions": Ernie4_5_MoeAttention,
} }
_keep_in_fp32_modules_strict = ["gate", "moe_statics"] _keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
def _init_weights(self, module): def _init_weights(self, module):
super()._init_weights(module) super()._init_weights(module)
if isinstance(module, Ernie4_5_MoEStatics): if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_() module.e_score_correction_bias.data.zero_()
@auto_docstring @auto_docstring
class Ernie4_5_MoEModel(Ernie4_5_MoEPreTrainedModel): class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
def __init__(self, config: Ernie4_5_MoEConfig): def __init__(self, config: Ernie4_5_MoeConfig):
super().__init__(config) super().__init__(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[Ernie4_5_MoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
self.norm = Ernie4_5_MoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Ernie4_5_MoERotaryEmbedding(config=config) self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
@@ -650,14 +652,14 @@ def load_balancing_loss_func(
@auto_docstring @auto_docstring
class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin): class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.model = Ernie4_5_MoEModel(config) self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -745,4 +747,4 @@ class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
) )
__all__ = ["Ernie4_5_MoEForCausalLM", "Ernie4_5_MoEModel", "Ernie4_5_MoEPreTrainedModel"] __all__ = ["Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeModel", "Ernie4_5_MoePreTrainedModel"]

View File

@@ -24,26 +24,25 @@ from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs from ...utils.generic import OutputRecorder, check_model_inputs
from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from ..mixtral.modeling_mixtral import ( from ..mixtral.modeling_mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
MixtralModel,
MixtralPreTrainedModel, MixtralPreTrainedModel,
) )
from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Ernie4_5_MoERMSNorm(LlamaRMSNorm): class Ernie4_5_MoeRMSNorm(LlamaRMSNorm):
pass pass
class Ernie4_5_MoEMLP(Qwen3MoeMLP): class Ernie4_5_MoeMLP(Qwen3MoeMLP):
def __init__(self, config, intermediate_size=None): def __init__(self, config, intermediate_size=None):
super().__init__(config, intermediate_size) super().__init__(config, intermediate_size)
@@ -52,12 +51,13 @@ class Ernie4_5_MoEMLP(Qwen3MoeMLP):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
class Ernie4_5_MoERotaryEmbedding(Ernie4_5RotaryEmbedding): class Ernie4_5_MoeRotaryEmbedding(Ernie4_5RotaryEmbedding):
pass def __init__(self, config: Ernie4_5_MoeConfig, device=None):
super().__init__(config, device)
class Ernie4_5_MoEAttention(LlamaAttention): class Ernie4_5_MoeAttention(LlamaAttention):
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int): def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__(config, layer_idx) super().__init__(config, layer_idx)
self.attention_dropout = 0.0 self.attention_dropout = 0.0
@@ -68,7 +68,7 @@ class Ernie4_5_MoEAttention(LlamaAttention):
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
class Ernie4_5_MoEStatics(nn.Module): class Ernie4_5_MoeStatics(nn.Module):
""" """
Stores MoE (Mixture of Experts) statistics Stores MoE (Mixture of Experts) statistics
- Bias for the gating - Bias for the gating
@@ -95,7 +95,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze() return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoESparseMoeBlock(nn.Module): class Ernie4_5_MoeSparseMoeBlock(nn.Module):
""" """
This implementation is This implementation is
strictly equivalent to standard MoE with full capacity (no strictly equivalent to standard MoE with full capacity (no
@@ -116,19 +116,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics) # correction bias (yes it seems to be a typo with statics <> statistics)
self.moe_statics = Ernie4_5_MoEStatics(config) self.moe_statics = Ernie4_5_MoeStatics(config)
# gating # gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList( self.experts = nn.ModuleList(
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)] [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
) )
self.norm_min = config.moe_norm_min self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards # (optional) shared experts for all forwards
self.shared_experts = None self.shared_experts = None
if config.moe_num_shared_experts > 0: if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward( def forward(
self, self,
@@ -190,38 +190,63 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits return final_hidden_states, router_logits
class Ernie4_5_MoEDecoderLayer(Qwen3MoeDecoderLayer, nn.Module): class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
def __init__(self, config, layer_idx): def __init__(self, config, layer_idx):
nn.Module().__init__() nn.Module().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx) self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if ( if (
((layer_idx + 1) % config.moe_layer_interval == 0) ((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index and layer_idx <= config.moe_layer_end_index
): ):
self.mlp = Ernie4_5_MoESparseMoeBlock(config) self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else: else:
self.mlp = Ernie4_5_MoEMLP(config) self.mlp = Ernie4_5_MoeMLP(config)
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps) self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
@auto_docstring @auto_docstring
class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel): class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
config: Ernie4_5_MoeConfig
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"] _keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
def _init_weights(self, module): def _init_weights(self, module):
MixtralPreTrainedModel._init_weights(module) MixtralPreTrainedModel._init_weights(module)
if isinstance(module, Ernie4_5_MoEStatics): if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_() module.e_score_correction_bias.data.zero_()
@auto_docstring @auto_docstring
class Ernie4_5_MoEModel(MixtralModel): class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
def __init__(self, config: Ernie4_5_MoeConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs @check_model_inputs
@auto_docstring @auto_docstring
def forward( def forward(
@@ -287,10 +312,10 @@ class Ernie4_5_MoEModel(MixtralModel):
@auto_docstring @auto_docstring
class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel): class Ernie4_5_MoeForCausalLM(MixtralForCausalLM, Ernie4_5_MoePreTrainedModel):
def __init__(self, config): def __init__(self, config):
Ernie4_5_MoEPreTrainedModel().__init__(config) Ernie4_5_MoePreTrainedModel().__init__(config)
self.model = Ernie4_5_MoEModel(config) self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -314,7 +339,7 @@ class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
__all__ = [ __all__ = [
"Ernie4_5_MoEForCausalLM", "Ernie4_5_MoeForCausalLM",
"Ernie4_5_MoEModel", "Ernie4_5_MoeModel",
"Ernie4_5_MoEPreTrainedModel", "Ernie4_5_MoePreTrainedModel",
] ]

View File

@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3") tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
model = Ernie4_5ForCausalLM.from_pretrained( model = Ernie4_5ForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-0.3B-PT", "baidu/ERNIE-4.5-0.3B-PT",
revision="refs/pr/3",
device_map="auto", device_map="auto",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )

View File

@@ -18,7 +18,7 @@ import unittest
import pytest import pytest
from transformers import Ernie4_5_MoEConfig, is_torch_available from transformers import Ernie4_5_MoeConfig, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
cleanup, cleanup,
is_flaky, is_flaky,
@@ -38,33 +38,33 @@ if is_torch_available():
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
Ernie4_5_MoEForCausalLM, Ernie4_5_MoeForCausalLM,
Ernie4_5_MoEModel, Ernie4_5_MoeModel,
) )
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class Ernie4_5_MoEModelTester(CausalLMModelTester): class Ernie4_5_MoeModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoEConfig config_class = Ernie4_5_MoeConfig
if is_torch_available(): if is_torch_available():
base_model_class = Ernie4_5_MoEModel base_model_class = Ernie4_5_MoeModel
causal_lm_class = Ernie4_5_MoEForCausalLM causal_lm_class = Ernie4_5_MoeForCausalLM
@require_torch @require_torch
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase): class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
Ernie4_5_MoEModel, Ernie4_5_MoeModel,
Ernie4_5_MoEForCausalLM, Ernie4_5_MoeForCausalLM,
) )
if is_torch_available() if is_torch_available()
else () else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": Ernie4_5_MoEModel, "feature-extraction": Ernie4_5_MoeModel,
"text-generation": Ernie4_5_MoEForCausalLM, "text-generation": Ernie4_5_MoeForCausalLM,
} }
if is_torch_available() if is_torch_available()
else {} else {}
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False test_headmasking = False
test_pruning = False test_pruning = False
test_all_params_have_gradient = False test_all_params_have_gradient = False
model_tester_class = Ernie4_5_MoEModelTester model_tester_class = Ernie4_5_MoeModelTester
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@slow @slow
def test_flash_attn_2_equivalence(self): def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2: if not model_class._supports_flash_attn:
self.skipTest(reason="Model does not support Flash Attention 2") self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
config.output_router_logits = True config.output_router_logits = True
input_ids = input_dict["input_ids"] input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device) attention_mask = input_ids.ne(1).to(torch_device)
model = Ernie4_5_MoEForCausalLM(config) model = Ernie4_5_MoeForCausalLM(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(input_ids, attention_mask=attention_mask) result = model(input_ids, attention_mask=attention_mask)
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch_multi_accelerator @require_torch_multi_accelerator
@require_torch_large_accelerator @require_torch_large_accelerator
@require_torch @require_torch
class Ernie4_5_MoEIntegrationTest(unittest.TestCase): class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = None cls.model = None
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
@classmethod @classmethod
def get_model(cls): def get_model(cls):
if cls.model is None: if cls.model is None:
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained( cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-21B-A3B-PT", "baidu/ERNIE-4.5-21B-A3B-PT",
revision="refs/pr/11",
device_map="auto", device_map="auto",
load_in_4bit=True, load_in_4bit=True,
) )

View File

@@ -33,7 +33,7 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = { SPECIAL_CASES_TO_ALLOW = {
"Ernie4_5Config": ["tie_word_embeddings"], "Ernie4_5Config": ["tie_word_embeddings"],
"Ernie4_5_MoEConfig": ["tie_word_embeddings"], "Ernie4_5_MoeConfig": ["tie_word_embeddings"],
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"], "Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
# used internally during generation to provide the custom logit processors with their necessary information # used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [ "DiaConfig": [