diff --git a/docs/source/en/model_doc/ernie4_5.md b/docs/source/en/model_doc/ernie4_5.md
index b350b9d429..af24c0b8bf 100644
--- a/docs/source/en/model_doc/ernie4_5.md
+++ b/docs/source/en/model_doc/ernie4_5.md
@@ -31,7 +31,7 @@ The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.b
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core.
-Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md).
+Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe.md).

diff --git a/docs/source/en/model_doc/ernie4_5_moe.md b/docs/source/en/model_doc/ernie4_5_moe.md
index 9d8703e592..8c3e75593c 100644
--- a/docs/source/en/model_doc/ernie4_5_moe.md
+++ b/docs/source/en/model_doc/ernie4_5_moe.md
@@ -23,11 +23,11 @@ rendered properly in your Markdown viewer.
-# Ernie 4.5 MoE
+# Ernie 4.5 Moe
## Overview
-The Ernie 4.5 MoE model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
+The Ernie 4.5 Moe model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters.
It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared
@@ -167,17 +167,17 @@ This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).
-## Ernie4_5_MoEConfig
+## Ernie4_5_MoeConfig
-[[autodoc]] Ernie4_5_MoEConfig
+[[autodoc]] Ernie4_5_MoeConfig
-## Ernie4_5_MoEModel
+## Ernie4_5_MoeModel
-[[autodoc]] Ernie4_5_MoEModel
+[[autodoc]] Ernie4_5_MoeModel
- forward
-## Ernie4_5_MoEForCausalLM
+## Ernie4_5_MoeForCausalLM
-[[autodoc]] Ernie4_5_MoEForCausalLM
+[[autodoc]] Ernie4_5_MoeForCausalLM
- forward
- generate
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 78c2ad034b..75a328d3cc 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -130,7 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("eomt", "EomtConfig"),
("ernie", "ErnieConfig"),
("ernie4_5", "Ernie4_5Config"),
- ("ernie4_5_moe", "Ernie4_5_MoEConfig"),
+ ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("falcon", "FalconConfig"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 20f039b22b..0574346832 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -121,7 +121,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie4_5", "Ernie4_5Model"),
- ("ernie4_5_moe", "Ernie4_5_MoEModel"),
+ ("ernie4_5_moe", "Ernie4_5_MoeModel"),
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("falcon", "FalconModel"),
@@ -597,7 +597,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3ForCausalLM"),
("ernie", "ErnieForCausalLM"),
("ernie4_5", "Ernie4_5ForCausalLM"),
- ("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"),
+ ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"),
diff --git a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py
index cec4f4661a..294ccfc638 100644
--- a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py
@@ -21,9 +21,9 @@ from ...utils import logging
logger = logging.get_logger(__name__)
-class Ernie4_5_MoEConfig(PretrainedConfig):
+class Ernie4_5_MoeConfig(PretrainedConfig):
r"""
- This is the configuration class to store the configuration of a [`Ernie4_5_MoEModel`]. It is used to instantiate a
+ This is the configuration class to store the configuration of a [`Ernie4_5_MoeModel`]. It is used to instantiate a
Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT).
@@ -34,7 +34,7 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 103424):
Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`Ernie4_5_MoEModel`]
+ `inputs_ids` passed when calling [`Ernie4_5_MoeModel`]
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
@@ -133,13 +133,13 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
The aux loss factor for the total loss.
```python
- >>> from transformers import Ernie4_5_MoEModel, Ernie4_5_MoEConfig
+ >>> from transformers import Ernie4_5_MoeModel, Ernie4_5_MoEConfig
>>> # Initializing a Ernie4_5_MoE style configuration
>>> configuration = Ernie4_5_MoEConfig()
>>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration
- >>> model = Ernie4_5_MoEModel(configuration)
+ >>> model = Ernie4_5_MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
@@ -251,4 +251,4 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
)
-__all__ = ["Ernie4_5_MoEConfig"]
+__all__ = ["Ernie4_5_MoeConfig"]
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index 14e598bff9..d583c27321 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -37,14 +37,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import OutputRecorder, check_model_inputs
-from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
+from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
@use_kernel_forward_from_hub("RMSNorm")
-class Ernie4_5_MoERMSNorm(nn.Module):
+class Ernie4_5_MoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
- Ernie4_5_MoERMSNorm is equivalent to T5LayerNorm
+ Ernie4_5_MoeRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -61,7 +61,7 @@ class Ernie4_5_MoERMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-class Ernie4_5_MoEMLP(nn.Module):
+class Ernie4_5_MoeMLP(nn.Module):
def __init__(self, config, intermediate_size=None):
super().__init__()
self.config = config
@@ -78,8 +78,8 @@ class Ernie4_5_MoEMLP(nn.Module):
return down_proj
-class Ernie4_5_MoERotaryEmbedding(nn.Module):
- def __init__(self, config: Ernie4_5_MoEConfig, device=None):
+class Ernie4_5_MoeRotaryEmbedding(nn.Module):
+ def __init__(self, config: Ernie4_5_MoeConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
@@ -194,10 +194,10 @@ def eager_attention_forward(
return attn_output, attn_weights
-class Ernie4_5_MoEAttention(nn.Module):
+class Ernie4_5_MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
+ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -257,7 +257,7 @@ class Ernie4_5_MoEAttention(nn.Module):
return attn_output, attn_weights
-class Ernie4_5_MoEStatics(nn.Module):
+class Ernie4_5_MoeStatics(nn.Module):
"""
Stores MoE (Mixture of Experts) statistics
- Bias for the gating
@@ -284,7 +284,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
-class Ernie4_5_MoESparseMoeBlock(nn.Module):
+class Ernie4_5_MoeSparseMoeBlock(nn.Module):
"""
This implementation is
strictly equivalent to standard MoE with full capacity (no
@@ -305,19 +305,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics)
- self.moe_statics = Ernie4_5_MoEStatics(config)
+ self.moe_statics = Ernie4_5_MoeStatics(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList(
- [Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
+ [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
)
self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards
self.shared_experts = None
if config.moe_num_shared_experts > 0:
- self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
+ self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward(
self,
@@ -379,24 +379,24 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits
-class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
+class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
+ self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if (
((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index
):
- self.mlp = Ernie4_5_MoESparseMoeBlock(config)
+ self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else:
- self.mlp = Ernie4_5_MoEMLP(config)
+ self.mlp = Ernie4_5_MoeMLP(config)
- self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
- self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
def forward(
self,
@@ -461,11 +461,11 @@ class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
-class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
- config: Ernie4_5_MoEConfig
+class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
+ config: Ernie4_5_MoeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
- _no_split_modules = ["Ernie4_5_MoEDecoderLayer"]
+ _no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
@@ -473,31 +473,33 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
- "router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1),
- "hidden_states": Ernie4_5_MoEDecoderLayer,
- "attentions": Ernie4_5_MoEAttention,
+ "router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
+ "hidden_states": Ernie4_5_MoeDecoderLayer,
+ "attentions": Ernie4_5_MoeAttention,
}
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
+ # Not supporting multi-token prediction (MTP) atm
+ _keys_to_ignore_on_load_unexpected = ["mtp"]
def _init_weights(self, module):
super()._init_weights(module)
- if isinstance(module, Ernie4_5_MoEStatics):
+ if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_()
@auto_docstring
-class Ernie4_5_MoEModel(Ernie4_5_MoEPreTrainedModel):
- def __init__(self, config: Ernie4_5_MoEConfig):
+class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
+ def __init__(self, config: Ernie4_5_MoeConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
- [Ernie4_5_MoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ [Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
- self.norm = Ernie4_5_MoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = Ernie4_5_MoERotaryEmbedding(config=config)
+ self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -650,14 +652,14 @@ def load_balancing_loss_func(
@auto_docstring
-class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
+class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
- self.model = Ernie4_5_MoEModel(config)
+ self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -745,4 +747,4 @@ class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
)
-__all__ = ["Ernie4_5_MoEForCausalLM", "Ernie4_5_MoEModel", "Ernie4_5_MoEPreTrainedModel"]
+__all__ = ["Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeModel", "Ernie4_5_MoePreTrainedModel"]
diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
index 3c4e068d37..76fec0ff3b 100644
--- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
@@ -24,26 +24,25 @@ from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs
from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from ..mixtral.modeling_mixtral import (
MixtralForCausalLM,
- MixtralModel,
MixtralPreTrainedModel,
)
from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP
-from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
+from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
logger = logging.get_logger(__name__)
-class Ernie4_5_MoERMSNorm(LlamaRMSNorm):
+class Ernie4_5_MoeRMSNorm(LlamaRMSNorm):
pass
-class Ernie4_5_MoEMLP(Qwen3MoeMLP):
+class Ernie4_5_MoeMLP(Qwen3MoeMLP):
def __init__(self, config, intermediate_size=None):
super().__init__(config, intermediate_size)
@@ -52,12 +51,13 @@ class Ernie4_5_MoEMLP(Qwen3MoeMLP):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
-class Ernie4_5_MoERotaryEmbedding(Ernie4_5RotaryEmbedding):
- pass
+class Ernie4_5_MoeRotaryEmbedding(Ernie4_5RotaryEmbedding):
+ def __init__(self, config: Ernie4_5_MoeConfig, device=None):
+ super().__init__(config, device)
-class Ernie4_5_MoEAttention(LlamaAttention):
- def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
+class Ernie4_5_MoeAttention(LlamaAttention):
+ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.attention_dropout = 0.0
@@ -68,7 +68,7 @@ class Ernie4_5_MoEAttention(LlamaAttention):
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-class Ernie4_5_MoEStatics(nn.Module):
+class Ernie4_5_MoeStatics(nn.Module):
"""
Stores MoE (Mixture of Experts) statistics
- Bias for the gating
@@ -95,7 +95,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
-class Ernie4_5_MoESparseMoeBlock(nn.Module):
+class Ernie4_5_MoeSparseMoeBlock(nn.Module):
"""
This implementation is
strictly equivalent to standard MoE with full capacity (no
@@ -116,19 +116,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics)
- self.moe_statics = Ernie4_5_MoEStatics(config)
+ self.moe_statics = Ernie4_5_MoeStatics(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList(
- [Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
+ [Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
)
self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards
self.shared_experts = None
if config.moe_num_shared_experts > 0:
- self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
+ self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward(
self,
@@ -190,38 +190,63 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits
-class Ernie4_5_MoEDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
+class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
def __init__(self, config, layer_idx):
nn.Module().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
+ self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if (
((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index
):
- self.mlp = Ernie4_5_MoESparseMoeBlock(config)
+ self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else:
- self.mlp = Ernie4_5_MoEMLP(config)
+ self.mlp = Ernie4_5_MoeMLP(config)
- self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
- self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
@auto_docstring
-class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel):
+class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
+ config: Ernie4_5_MoeConfig
+ _no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
+ # Not supporting multi-token prediction (MTP) atm
+ _keys_to_ignore_on_load_unexpected = ["mtp"]
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
+ "hidden_states": Ernie4_5_MoeDecoderLayer,
+ "attentions": Ernie4_5_MoeAttention,
+ }
def _init_weights(self, module):
MixtralPreTrainedModel._init_weights(module)
- if isinstance(module, Ernie4_5_MoEStatics):
+ if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_()
@auto_docstring
-class Ernie4_5_MoEModel(MixtralModel):
+class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
+ def __init__(self, config: Ernie4_5_MoeConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
@check_model_inputs
@auto_docstring
def forward(
@@ -287,10 +312,10 @@ class Ernie4_5_MoEModel(MixtralModel):
@auto_docstring
-class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
+class Ernie4_5_MoeForCausalLM(MixtralForCausalLM, Ernie4_5_MoePreTrainedModel):
def __init__(self, config):
- Ernie4_5_MoEPreTrainedModel().__init__(config)
- self.model = Ernie4_5_MoEModel(config)
+ Ernie4_5_MoePreTrainedModel().__init__(config)
+ self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -314,7 +339,7 @@ class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
__all__ = [
- "Ernie4_5_MoEForCausalLM",
- "Ernie4_5_MoEModel",
- "Ernie4_5_MoEPreTrainedModel",
+ "Ernie4_5_MoeForCausalLM",
+ "Ernie4_5_MoeModel",
+ "Ernie4_5_MoePreTrainedModel",
]
diff --git a/tests/models/ernie4_5/test_modeling_ernie4_5.py b/tests/models/ernie4_5/test_modeling_ernie4_5.py
index 1c5bffa2c6..ea5a91804e 100644
--- a/tests/models/ernie4_5/test_modeling_ernie4_5.py
+++ b/tests/models/ernie4_5/test_modeling_ernie4_5.py
@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
model = Ernie4_5ForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-0.3B-PT",
- revision="refs/pr/3",
device_map="auto",
torch_dtype=torch.bfloat16,
)
diff --git a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
index b8a8130155..393f8ee5aa 100644
--- a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
+++ b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py
@@ -18,7 +18,7 @@ import unittest
import pytest
-from transformers import Ernie4_5_MoEConfig, is_torch_available
+from transformers import Ernie4_5_MoeConfig, is_torch_available
from transformers.testing_utils import (
cleanup,
is_flaky,
@@ -38,33 +38,33 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
- Ernie4_5_MoEForCausalLM,
- Ernie4_5_MoEModel,
+ Ernie4_5_MoeForCausalLM,
+ Ernie4_5_MoeModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
-class Ernie4_5_MoEModelTester(CausalLMModelTester):
- config_class = Ernie4_5_MoEConfig
+class Ernie4_5_MoeModelTester(CausalLMModelTester):
+ config_class = Ernie4_5_MoeConfig
if is_torch_available():
- base_model_class = Ernie4_5_MoEModel
- causal_lm_class = Ernie4_5_MoEForCausalLM
+ base_model_class = Ernie4_5_MoeModel
+ causal_lm_class = Ernie4_5_MoeForCausalLM
@require_torch
-class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
+class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
- Ernie4_5_MoEModel,
- Ernie4_5_MoEForCausalLM,
+ Ernie4_5_MoeModel,
+ Ernie4_5_MoeForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
- "feature-extraction": Ernie4_5_MoEModel,
- "text-generation": Ernie4_5_MoEForCausalLM,
+ "feature-extraction": Ernie4_5_MoeModel,
+ "text-generation": Ernie4_5_MoeForCausalLM,
}
if is_torch_available()
else {}
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
test_all_params_have_gradient = False
- model_tester_class = Ernie4_5_MoEModelTester
+ model_tester_class = Ernie4_5_MoeModelTester
@require_flash_attn
@require_torch_gpu
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
+ if not model_class._supports_flash_attn:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
- model = Ernie4_5_MoEForCausalLM(config)
+ model = Ernie4_5_MoeForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch_multi_accelerator
@require_torch_large_accelerator
@require_torch
-class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
+class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
@classmethod
def get_model(cls):
if cls.model is None:
- cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
+ cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-21B-A3B-PT",
- revision="refs/pr/11",
device_map="auto",
load_in_4bit=True,
)
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 08e3f26245..d8bd32847e 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -33,7 +33,7 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
"Ernie4_5Config": ["tie_word_embeddings"],
- "Ernie4_5_MoEConfig": ["tie_word_embeddings"],
+ "Ernie4_5_MoeConfig": ["tie_word_embeddings"],
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
# used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [