[Ernie 4.5] Post merge adaptations (#39664)
* ernie 4.5 fixes * Apply style fixes * fix --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,7 @@ The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.b
|
||||
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
|
||||
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core.
|
||||
|
||||
Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md).
|
||||
Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe.md).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
|
||||
|
||||
@@ -23,11 +23,11 @@ rendered properly in your Markdown viewer.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Ernie 4.5 MoE
|
||||
# Ernie 4.5 Moe
|
||||
|
||||
## Overview
|
||||
|
||||
The Ernie 4.5 MoE model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
|
||||
The Ernie 4.5 Moe model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
|
||||
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
|
||||
model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters.
|
||||
It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared
|
||||
@@ -167,17 +167,17 @@ This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
|
||||
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).
|
||||
|
||||
|
||||
## Ernie4_5_MoEConfig
|
||||
## Ernie4_5_MoeConfig
|
||||
|
||||
[[autodoc]] Ernie4_5_MoEConfig
|
||||
[[autodoc]] Ernie4_5_MoeConfig
|
||||
|
||||
## Ernie4_5_MoEModel
|
||||
## Ernie4_5_MoeModel
|
||||
|
||||
[[autodoc]] Ernie4_5_MoEModel
|
||||
[[autodoc]] Ernie4_5_MoeModel
|
||||
- forward
|
||||
|
||||
## Ernie4_5_MoEForCausalLM
|
||||
## Ernie4_5_MoeForCausalLM
|
||||
|
||||
[[autodoc]] Ernie4_5_MoEForCausalLM
|
||||
[[autodoc]] Ernie4_5_MoeForCausalLM
|
||||
- forward
|
||||
- generate
|
||||
|
||||
@@ -130,7 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("eomt", "EomtConfig"),
|
||||
("ernie", "ErnieConfig"),
|
||||
("ernie4_5", "Ernie4_5Config"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoEConfig"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoeConfig"),
|
||||
("ernie_m", "ErnieMConfig"),
|
||||
("esm", "EsmConfig"),
|
||||
("falcon", "FalconConfig"),
|
||||
|
||||
@@ -121,7 +121,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("encodec", "EncodecModel"),
|
||||
("ernie", "ErnieModel"),
|
||||
("ernie4_5", "Ernie4_5Model"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoEModel"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoeModel"),
|
||||
("ernie_m", "ErnieMModel"),
|
||||
("esm", "EsmModel"),
|
||||
("falcon", "FalconModel"),
|
||||
@@ -597,7 +597,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("ernie4_5", "Ernie4_5ForCausalLM"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"),
|
||||
("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
("falcon_h1", "FalconH1ForCausalLM"),
|
||||
("falcon_mamba", "FalconMambaForCausalLM"),
|
||||
|
||||
@@ -21,9 +21,9 @@ from ...utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Ernie4_5_MoEConfig(PretrainedConfig):
|
||||
class Ernie4_5_MoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Ernie4_5_MoEModel`]. It is used to instantiate a
|
||||
This is the configuration class to store the configuration of a [`Ernie4_5_MoeModel`]. It is used to instantiate a
|
||||
Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT).
|
||||
|
||||
@@ -34,7 +34,7 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 103424):
|
||||
Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Ernie4_5_MoEModel`]
|
||||
`inputs_ids` passed when calling [`Ernie4_5_MoeModel`]
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
@@ -133,13 +133,13 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
|
||||
The aux loss factor for the total loss.
|
||||
|
||||
```python
|
||||
>>> from transformers import Ernie4_5_MoEModel, Ernie4_5_MoEConfig
|
||||
>>> from transformers import Ernie4_5_MoeModel, Ernie4_5_MoEConfig
|
||||
|
||||
>>> # Initializing a Ernie4_5_MoE style configuration
|
||||
>>> configuration = Ernie4_5_MoEConfig()
|
||||
|
||||
>>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration
|
||||
>>> model = Ernie4_5_MoEModel(configuration)
|
||||
>>> model = Ernie4_5_MoeModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
@@ -251,4 +251,4 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Ernie4_5_MoEConfig"]
|
||||
__all__ = ["Ernie4_5_MoeConfig"]
|
||||
|
||||
@@ -37,14 +37,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import OutputRecorder, check_model_inputs
|
||||
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
|
||||
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class Ernie4_5_MoERMSNorm(nn.Module):
|
||||
class Ernie4_5_MoeRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Ernie4_5_MoERMSNorm is equivalent to T5LayerNorm
|
||||
Ernie4_5_MoeRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
@@ -61,7 +61,7 @@ class Ernie4_5_MoERMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Ernie4_5_MoEMLP(nn.Module):
|
||||
class Ernie4_5_MoeMLP(nn.Module):
|
||||
def __init__(self, config, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -78,8 +78,8 @@ class Ernie4_5_MoEMLP(nn.Module):
|
||||
return down_proj
|
||||
|
||||
|
||||
class Ernie4_5_MoERotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Ernie4_5_MoEConfig, device=None):
|
||||
class Ernie4_5_MoeRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
||||
@@ -194,10 +194,10 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Ernie4_5_MoEAttention(nn.Module):
|
||||
class Ernie4_5_MoeAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
@@ -257,7 +257,7 @@ class Ernie4_5_MoEAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Ernie4_5_MoEStatics(nn.Module):
|
||||
class Ernie4_5_MoeStatics(nn.Module):
|
||||
"""
|
||||
Stores MoE (Mixture of Experts) statistics
|
||||
- Bias for the gating
|
||||
@@ -284,7 +284,7 @@ class Ernie4_5_MoEStatics(nn.Module):
|
||||
return hidden_states + self.e_score_correction_bias.squeeze()
|
||||
|
||||
|
||||
class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
|
||||
"""
|
||||
This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
@@ -305,19 +305,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
self.top_k = config.moe_k
|
||||
|
||||
# correction bias (yes it seems to be a typo with statics <> statistics)
|
||||
self.moe_statics = Ernie4_5_MoEStatics(config)
|
||||
self.moe_statics = Ernie4_5_MoeStatics(config)
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
|
||||
self.experts = nn.ModuleList(
|
||||
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
|
||||
[Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
|
||||
)
|
||||
self.norm_min = config.moe_norm_min
|
||||
|
||||
# (optional) shared experts for all forwards
|
||||
self.shared_experts = None
|
||||
if config.moe_num_shared_experts > 0:
|
||||
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
|
||||
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -379,24 +379,24 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
|
||||
class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
|
||||
class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
|
||||
self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
|
||||
|
||||
if (
|
||||
((layer_idx + 1) % config.moe_layer_interval == 0)
|
||||
and layer_idx >= config.moe_layer_start_index
|
||||
and layer_idx <= config.moe_layer_end_index
|
||||
):
|
||||
self.mlp = Ernie4_5_MoESparseMoeBlock(config)
|
||||
self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
|
||||
else:
|
||||
self.mlp = Ernie4_5_MoEMLP(config)
|
||||
self.mlp = Ernie4_5_MoeMLP(config)
|
||||
|
||||
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -461,11 +461,11 @@ class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
|
||||
config: Ernie4_5_MoEConfig
|
||||
class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
|
||||
config: Ernie4_5_MoeConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Ernie4_5_MoEDecoderLayer"]
|
||||
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
@@ -473,31 +473,33 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
|
||||
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1),
|
||||
"hidden_states": Ernie4_5_MoEDecoderLayer,
|
||||
"attentions": Ernie4_5_MoEAttention,
|
||||
"router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
|
||||
"hidden_states": Ernie4_5_MoeDecoderLayer,
|
||||
"attentions": Ernie4_5_MoeAttention,
|
||||
}
|
||||
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
|
||||
# Not supporting multi-token prediction (MTP) atm
|
||||
_keys_to_ignore_on_load_unexpected = ["mtp"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, Ernie4_5_MoEStatics):
|
||||
if isinstance(module, Ernie4_5_MoeStatics):
|
||||
module.e_score_correction_bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEModel(Ernie4_5_MoEPreTrainedModel):
|
||||
def __init__(self, config: Ernie4_5_MoEConfig):
|
||||
class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[Ernie4_5_MoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
[Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = Ernie4_5_MoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Ernie4_5_MoERotaryEmbedding(config=config)
|
||||
self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@@ -650,14 +652,14 @@ def load_balancing_loss_func(
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
|
||||
class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Ernie4_5_MoEModel(config)
|
||||
self.model = Ernie4_5_MoeModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
|
||||
|
||||
@@ -745,4 +747,4 @@ class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Ernie4_5_MoEForCausalLM", "Ernie4_5_MoEModel", "Ernie4_5_MoEPreTrainedModel"]
|
||||
__all__ = ["Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeModel", "Ernie4_5_MoePreTrainedModel"]
|
||||
|
||||
@@ -24,26 +24,25 @@ from ...masking_utils import create_causal_mask
|
||||
from ...modeling_outputs import MoeModelOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.generic import check_model_inputs
|
||||
from ...utils.generic import OutputRecorder, check_model_inputs
|
||||
from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401
|
||||
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
|
||||
from ..mixtral.modeling_mixtral import (
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
MixtralPreTrainedModel,
|
||||
)
|
||||
from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP
|
||||
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
|
||||
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Ernie4_5_MoERMSNorm(LlamaRMSNorm):
|
||||
class Ernie4_5_MoeRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Ernie4_5_MoEMLP(Qwen3MoeMLP):
|
||||
class Ernie4_5_MoeMLP(Qwen3MoeMLP):
|
||||
def __init__(self, config, intermediate_size=None):
|
||||
super().__init__(config, intermediate_size)
|
||||
|
||||
@@ -52,12 +51,13 @@ class Ernie4_5_MoEMLP(Qwen3MoeMLP):
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
||||
|
||||
|
||||
class Ernie4_5_MoERotaryEmbedding(Ernie4_5RotaryEmbedding):
|
||||
pass
|
||||
class Ernie4_5_MoeRotaryEmbedding(Ernie4_5RotaryEmbedding):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
|
||||
super().__init__(config, device)
|
||||
|
||||
|
||||
class Ernie4_5_MoEAttention(LlamaAttention):
|
||||
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
|
||||
class Ernie4_5_MoeAttention(LlamaAttention):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self.attention_dropout = 0.0
|
||||
@@ -68,7 +68,7 @@ class Ernie4_5_MoEAttention(LlamaAttention):
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
|
||||
|
||||
|
||||
class Ernie4_5_MoEStatics(nn.Module):
|
||||
class Ernie4_5_MoeStatics(nn.Module):
|
||||
"""
|
||||
Stores MoE (Mixture of Experts) statistics
|
||||
- Bias for the gating
|
||||
@@ -95,7 +95,7 @@ class Ernie4_5_MoEStatics(nn.Module):
|
||||
return hidden_states + self.e_score_correction_bias.squeeze()
|
||||
|
||||
|
||||
class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
|
||||
"""
|
||||
This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
@@ -116,19 +116,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
self.top_k = config.moe_k
|
||||
|
||||
# correction bias (yes it seems to be a typo with statics <> statistics)
|
||||
self.moe_statics = Ernie4_5_MoEStatics(config)
|
||||
self.moe_statics = Ernie4_5_MoeStatics(config)
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
|
||||
self.experts = nn.ModuleList(
|
||||
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
|
||||
[Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
|
||||
)
|
||||
self.norm_min = config.moe_norm_min
|
||||
|
||||
# (optional) shared experts for all forwards
|
||||
self.shared_experts = None
|
||||
if config.moe_num_shared_experts > 0:
|
||||
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
|
||||
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -190,38 +190,63 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
|
||||
class Ernie4_5_MoEDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
|
||||
class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
nn.Module().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
|
||||
self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
|
||||
|
||||
if (
|
||||
((layer_idx + 1) % config.moe_layer_interval == 0)
|
||||
and layer_idx >= config.moe_layer_start_index
|
||||
and layer_idx <= config.moe_layer_end_index
|
||||
):
|
||||
self.mlp = Ernie4_5_MoESparseMoeBlock(config)
|
||||
self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
|
||||
else:
|
||||
self.mlp = Ernie4_5_MoEMLP(config)
|
||||
self.mlp = Ernie4_5_MoeMLP(config)
|
||||
|
||||
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel):
|
||||
class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
|
||||
config: Ernie4_5_MoeConfig
|
||||
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
|
||||
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
|
||||
# Not supporting multi-token prediction (MTP) atm
|
||||
_keys_to_ignore_on_load_unexpected = ["mtp"]
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
|
||||
"hidden_states": Ernie4_5_MoeDecoderLayer,
|
||||
"attentions": Ernie4_5_MoeAttention,
|
||||
}
|
||||
|
||||
def _init_weights(self, module):
|
||||
MixtralPreTrainedModel._init_weights(module)
|
||||
if isinstance(module, Ernie4_5_MoEStatics):
|
||||
if isinstance(module, Ernie4_5_MoeStatics):
|
||||
module.e_score_correction_bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEModel(MixtralModel):
|
||||
class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
|
||||
def __init__(self, config: Ernie4_5_MoeConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@check_model_inputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
@@ -287,10 +312,10 @@ class Ernie4_5_MoEModel(MixtralModel):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
|
||||
class Ernie4_5_MoeForCausalLM(MixtralForCausalLM, Ernie4_5_MoePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
Ernie4_5_MoEPreTrainedModel().__init__(config)
|
||||
self.model = Ernie4_5_MoEModel(config)
|
||||
Ernie4_5_MoePreTrainedModel().__init__(config)
|
||||
self.model = Ernie4_5_MoeModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
|
||||
|
||||
@@ -314,7 +339,7 @@ class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Ernie4_5_MoEForCausalLM",
|
||||
"Ernie4_5_MoEModel",
|
||||
"Ernie4_5_MoEPreTrainedModel",
|
||||
"Ernie4_5_MoeForCausalLM",
|
||||
"Ernie4_5_MoeModel",
|
||||
"Ernie4_5_MoePreTrainedModel",
|
||||
]
|
||||
|
||||
@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
|
||||
model = Ernie4_5ForCausalLM.from_pretrained(
|
||||
"baidu/ERNIE-4.5-0.3B-PT",
|
||||
revision="refs/pr/3",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import Ernie4_5_MoEConfig, is_torch_available
|
||||
from transformers import Ernie4_5_MoeConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
is_flaky,
|
||||
@@ -38,33 +38,33 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Ernie4_5_MoEForCausalLM,
|
||||
Ernie4_5_MoEModel,
|
||||
Ernie4_5_MoeForCausalLM,
|
||||
Ernie4_5_MoeModel,
|
||||
)
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
class Ernie4_5_MoEModelTester(CausalLMModelTester):
|
||||
config_class = Ernie4_5_MoEConfig
|
||||
class Ernie4_5_MoeModelTester(CausalLMModelTester):
|
||||
config_class = Ernie4_5_MoeConfig
|
||||
if is_torch_available():
|
||||
base_model_class = Ernie4_5_MoEModel
|
||||
causal_lm_class = Ernie4_5_MoEForCausalLM
|
||||
base_model_class = Ernie4_5_MoeModel
|
||||
causal_lm_class = Ernie4_5_MoeForCausalLM
|
||||
|
||||
|
||||
@require_torch
|
||||
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Ernie4_5_MoEModel,
|
||||
Ernie4_5_MoEForCausalLM,
|
||||
Ernie4_5_MoeModel,
|
||||
Ernie4_5_MoeForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Ernie4_5_MoEModel,
|
||||
"text-generation": Ernie4_5_MoEForCausalLM,
|
||||
"feature-extraction": Ernie4_5_MoeModel,
|
||||
"text-generation": Ernie4_5_MoeForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
test_all_params_have_gradient = False
|
||||
model_tester_class = Ernie4_5_MoEModelTester
|
||||
model_tester_class = Ernie4_5_MoeModelTester
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
config.output_router_logits = True
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
model = Ernie4_5_MoEForCausalLM(config)
|
||||
model = Ernie4_5_MoeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask)
|
||||
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
@require_torch_multi_accelerator
|
||||
@require_torch_large_accelerator
|
||||
@require_torch
|
||||
class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
|
||||
class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = None
|
||||
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def get_model(cls):
|
||||
if cls.model is None:
|
||||
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
|
||||
cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
|
||||
"baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
revision="refs/pr/11",
|
||||
device_map="auto",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
SPECIAL_CASES_TO_ALLOW = {
|
||||
"Ernie4_5Config": ["tie_word_embeddings"],
|
||||
"Ernie4_5_MoEConfig": ["tie_word_embeddings"],
|
||||
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
|
||||
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
|
||||
# used internally during generation to provide the custom logit processors with their necessary information
|
||||
"DiaConfig": [
|
||||
|
||||
Reference in New Issue
Block a user