[Ernie 4.5] Post merge adaptations (#39664)

* ernie 4.5 fixes

* Apply style fixes

* fix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Anton Vlasjuk
2025-07-25 17:36:18 +02:00
committed by GitHub
parent 5d0ba3e479
commit a91653561e
10 changed files with 126 additions and 101 deletions

View File

@@ -31,7 +31,7 @@ The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.b
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core.
Other models from the family can be found at [Ernie 4.5 MoE](./ernie4_5_moe.md).
Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe.md).
<div class="flex justify-center">
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>

View File

@@ -23,11 +23,11 @@ rendered properly in your Markdown viewer.
</div>
</div>
# Ernie 4.5 MoE
# Ernie 4.5 Moe
## Overview
The Ernie 4.5 MoE model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
The Ernie 4.5 Moe model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters.
It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared
@@ -167,17 +167,17 @@ This model was contributed by [Anton Vlasjuk](https://huggingface.co/AntonV).
The original code can be found [here](https://github.com/PaddlePaddle/ERNIE).
## Ernie4_5_MoEConfig
## Ernie4_5_MoeConfig
[[autodoc]] Ernie4_5_MoEConfig
[[autodoc]] Ernie4_5_MoeConfig
## Ernie4_5_MoEModel
## Ernie4_5_MoeModel
[[autodoc]] Ernie4_5_MoEModel
[[autodoc]] Ernie4_5_MoeModel
- forward
## Ernie4_5_MoEForCausalLM
## Ernie4_5_MoeForCausalLM
[[autodoc]] Ernie4_5_MoEForCausalLM
[[autodoc]] Ernie4_5_MoeForCausalLM
- forward
- generate

View File

@@ -130,7 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("eomt", "EomtConfig"),
("ernie", "ErnieConfig"),
("ernie4_5", "Ernie4_5Config"),
("ernie4_5_moe", "Ernie4_5_MoEConfig"),
("ernie4_5_moe", "Ernie4_5_MoeConfig"),
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("falcon", "FalconConfig"),

View File

@@ -121,7 +121,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie4_5", "Ernie4_5Model"),
("ernie4_5_moe", "Ernie4_5_MoEModel"),
("ernie4_5_moe", "Ernie4_5_MoeModel"),
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("falcon", "FalconModel"),
@@ -597,7 +597,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3ForCausalLM"),
("ernie", "ErnieForCausalLM"),
("ernie4_5", "Ernie4_5ForCausalLM"),
("ernie4_5_moe", "Ernie4_5_MoEForCausalLM"),
("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"),

View File

@@ -21,9 +21,9 @@ from ...utils import logging
logger = logging.get_logger(__name__)
class Ernie4_5_MoEConfig(PretrainedConfig):
class Ernie4_5_MoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Ernie4_5_MoEModel`]. It is used to instantiate a
This is the configuration class to store the configuration of a [`Ernie4_5_MoeModel`]. It is used to instantiate a
Ernie 4.5 MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of [baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT).
@@ -34,7 +34,7 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 103424):
Vocabulary size of the Ernie 4.5 MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Ernie4_5_MoEModel`]
`inputs_ids` passed when calling [`Ernie4_5_MoeModel`]
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
@@ -133,13 +133,13 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
The aux loss factor for the total loss.
```python
>>> from transformers import Ernie4_5_MoEModel, Ernie4_5_MoEConfig
>>> from transformers import Ernie4_5_MoeModel, Ernie4_5_MoEConfig
>>> # Initializing a Ernie4_5_MoE style configuration
>>> configuration = Ernie4_5_MoEConfig()
>>> # Initializing a model from the ERNIE-4.5-21B-A3B style configuration
>>> model = Ernie4_5_MoEModel(configuration)
>>> model = Ernie4_5_MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
@@ -251,4 +251,4 @@ class Ernie4_5_MoEConfig(PretrainedConfig):
)
__all__ = ["Ernie4_5_MoEConfig"]
__all__ = ["Ernie4_5_MoeConfig"]

View File

@@ -37,14 +37,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import OutputRecorder, check_model_inputs
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
@use_kernel_forward_from_hub("RMSNorm")
class Ernie4_5_MoERMSNorm(nn.Module):
class Ernie4_5_MoeRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Ernie4_5_MoERMSNorm is equivalent to T5LayerNorm
Ernie4_5_MoeRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -61,7 +61,7 @@ class Ernie4_5_MoERMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Ernie4_5_MoEMLP(nn.Module):
class Ernie4_5_MoeMLP(nn.Module):
def __init__(self, config, intermediate_size=None):
super().__init__()
self.config = config
@@ -78,8 +78,8 @@ class Ernie4_5_MoEMLP(nn.Module):
return down_proj
class Ernie4_5_MoERotaryEmbedding(nn.Module):
def __init__(self, config: Ernie4_5_MoEConfig, device=None):
class Ernie4_5_MoeRotaryEmbedding(nn.Module):
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
@@ -194,10 +194,10 @@ def eager_attention_forward(
return attn_output, attn_weights
class Ernie4_5_MoEAttention(nn.Module):
class Ernie4_5_MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -257,7 +257,7 @@ class Ernie4_5_MoEAttention(nn.Module):
return attn_output, attn_weights
class Ernie4_5_MoEStatics(nn.Module):
class Ernie4_5_MoeStatics(nn.Module):
"""
Stores MoE (Mixture of Experts) statistics
- Bias for the gating
@@ -284,7 +284,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoESparseMoeBlock(nn.Module):
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
"""
This implementation is
strictly equivalent to standard MoE with full capacity (no
@@ -305,19 +305,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics)
self.moe_statics = Ernie4_5_MoEStatics(config)
self.moe_statics = Ernie4_5_MoeStatics(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList(
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
[Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
)
self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward(
self,
@@ -379,24 +379,24 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits
class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if (
((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index
):
self.mlp = Ernie4_5_MoESparseMoeBlock(config)
self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else:
self.mlp = Ernie4_5_MoEMLP(config)
self.mlp = Ernie4_5_MoeMLP(config)
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
def forward(
self,
@@ -461,11 +461,11 @@ class Ernie4_5_MoEDecoderLayer(GradientCheckpointingLayer):
@auto_docstring
class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
config: Ernie4_5_MoEConfig
class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
config: Ernie4_5_MoeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Ernie4_5_MoEDecoderLayer"]
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
@@ -473,31 +473,33 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1),
"hidden_states": Ernie4_5_MoEDecoderLayer,
"attentions": Ernie4_5_MoEAttention,
"router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, Ernie4_5_MoEStatics):
if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_()
@auto_docstring
class Ernie4_5_MoEModel(Ernie4_5_MoEPreTrainedModel):
def __init__(self, config: Ernie4_5_MoEConfig):
class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
def __init__(self, config: Ernie4_5_MoeConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Ernie4_5_MoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
[Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Ernie4_5_MoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Ernie4_5_MoERotaryEmbedding(config=config)
self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -650,14 +652,14 @@ def load_balancing_loss_func(
@auto_docstring
class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Ernie4_5_MoEModel(config)
self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -745,4 +747,4 @@ class Ernie4_5_MoEForCausalLM(Ernie4_5_MoEPreTrainedModel, GenerationMixin):
)
__all__ = ["Ernie4_5_MoEForCausalLM", "Ernie4_5_MoEModel", "Ernie4_5_MoEPreTrainedModel"]
__all__ = ["Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeModel", "Ernie4_5_MoePreTrainedModel"]

View File

@@ -24,26 +24,25 @@ from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from ...utils.generic import OutputRecorder, check_model_inputs
from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401
from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from ..mixtral.modeling_mixtral import (
MixtralForCausalLM,
MixtralModel,
MixtralPreTrainedModel,
)
from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeMLP
from .configuration_ernie4_5_moe import Ernie4_5_MoEConfig
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
logger = logging.get_logger(__name__)
class Ernie4_5_MoERMSNorm(LlamaRMSNorm):
class Ernie4_5_MoeRMSNorm(LlamaRMSNorm):
pass
class Ernie4_5_MoEMLP(Qwen3MoeMLP):
class Ernie4_5_MoeMLP(Qwen3MoeMLP):
def __init__(self, config, intermediate_size=None):
super().__init__(config, intermediate_size)
@@ -52,12 +51,13 @@ class Ernie4_5_MoEMLP(Qwen3MoeMLP):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
class Ernie4_5_MoERotaryEmbedding(Ernie4_5RotaryEmbedding):
pass
class Ernie4_5_MoeRotaryEmbedding(Ernie4_5RotaryEmbedding):
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
super().__init__(config, device)
class Ernie4_5_MoEAttention(LlamaAttention):
def __init__(self, config: Ernie4_5_MoEConfig, layer_idx: int):
class Ernie4_5_MoeAttention(LlamaAttention):
def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.attention_dropout = 0.0
@@ -68,7 +68,7 @@ class Ernie4_5_MoEAttention(LlamaAttention):
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
class Ernie4_5_MoEStatics(nn.Module):
class Ernie4_5_MoeStatics(nn.Module):
"""
Stores MoE (Mixture of Experts) statistics
- Bias for the gating
@@ -95,7 +95,7 @@ class Ernie4_5_MoEStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoESparseMoeBlock(nn.Module):
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
"""
This implementation is
strictly equivalent to standard MoE with full capacity (no
@@ -116,19 +116,19 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
self.top_k = config.moe_k
# correction bias (yes it seems to be a typo with statics <> statistics)
self.moe_statics = Ernie4_5_MoEStatics(config)
self.moe_statics = Ernie4_5_MoeStatics(config)
# gating
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList(
[Ernie4_5_MoEMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
[Ernie4_5_MoeMLP(config, config.moe_intermediate_size) for _ in range(config.moe_num_experts)]
)
self.norm_min = config.moe_norm_min
# (optional) shared experts for all forwards
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoEMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def forward(
self,
@@ -190,38 +190,63 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
return final_hidden_states, router_logits
class Ernie4_5_MoEDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
def __init__(self, config, layer_idx):
nn.Module().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Ernie4_5_MoEAttention(config, layer_idx)
self.self_attn = Ernie4_5_MoeAttention(config, layer_idx)
if (
((layer_idx + 1) % config.moe_layer_interval == 0)
and layer_idx >= config.moe_layer_start_index
and layer_idx <= config.moe_layer_end_index
):
self.mlp = Ernie4_5_MoESparseMoeBlock(config)
self.mlp = Ernie4_5_MoeSparseMoeBlock(config)
else:
self.mlp = Ernie4_5_MoEMLP(config)
self.mlp = Ernie4_5_MoeMLP(config)
self.input_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoERMSNorm(config.hidden_size, config.rms_norm_eps)
self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps)
@auto_docstring
class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel):
class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
config: Ernie4_5_MoeConfig
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeSparseMoeBlock, index=1),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
def _init_weights(self, module):
MixtralPreTrainedModel._init_weights(module)
if isinstance(module, Ernie4_5_MoEStatics):
if isinstance(module, Ernie4_5_MoeStatics):
module.e_score_correction_bias.data.zero_()
@auto_docstring
class Ernie4_5_MoEModel(MixtralModel):
class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel):
def __init__(self, config: Ernie4_5_MoeConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Ernie4_5_MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Ernie4_5_MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Ernie4_5_MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs
@auto_docstring
def forward(
@@ -287,10 +312,10 @@ class Ernie4_5_MoEModel(MixtralModel):
@auto_docstring
class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
class Ernie4_5_MoeForCausalLM(MixtralForCausalLM, Ernie4_5_MoePreTrainedModel):
def __init__(self, config):
Ernie4_5_MoEPreTrainedModel().__init__(config)
self.model = Ernie4_5_MoEModel(config)
Ernie4_5_MoePreTrainedModel().__init__(config)
self.model = Ernie4_5_MoeModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.use_bias)
@@ -314,7 +339,7 @@ class Ernie4_5_MoEForCausalLM(MixtralForCausalLM, Ernie4_5_MoEPreTrainedModel):
__all__ = [
"Ernie4_5_MoEForCausalLM",
"Ernie4_5_MoEModel",
"Ernie4_5_MoEPreTrainedModel",
"Ernie4_5_MoeForCausalLM",
"Ernie4_5_MoeModel",
"Ernie4_5_MoePreTrainedModel",
]

View File

@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
model = Ernie4_5ForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-0.3B-PT",
revision="refs/pr/3",
device_map="auto",
torch_dtype=torch.bfloat16,
)

View File

@@ -18,7 +18,7 @@ import unittest
import pytest
from transformers import Ernie4_5_MoEConfig, is_torch_available
from transformers import Ernie4_5_MoeConfig, is_torch_available
from transformers.testing_utils import (
cleanup,
is_flaky,
@@ -38,33 +38,33 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
Ernie4_5_MoEForCausalLM,
Ernie4_5_MoEModel,
Ernie4_5_MoeForCausalLM,
Ernie4_5_MoeModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class Ernie4_5_MoEModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoEConfig
class Ernie4_5_MoeModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoeConfig
if is_torch_available():
base_model_class = Ernie4_5_MoEModel
causal_lm_class = Ernie4_5_MoEForCausalLM
base_model_class = Ernie4_5_MoeModel
causal_lm_class = Ernie4_5_MoeForCausalLM
@require_torch
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
Ernie4_5_MoEModel,
Ernie4_5_MoEForCausalLM,
Ernie4_5_MoeModel,
Ernie4_5_MoeForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": Ernie4_5_MoEModel,
"text-generation": Ernie4_5_MoEForCausalLM,
"feature-extraction": Ernie4_5_MoeModel,
"text-generation": Ernie4_5_MoeForCausalLM,
}
if is_torch_available()
else {}
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
test_all_params_have_gradient = False
model_tester_class = Ernie4_5_MoEModelTester
model_tester_class = Ernie4_5_MoeModelTester
@require_flash_attn
@require_torch_gpu
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
if not model_class._supports_flash_attn:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = Ernie4_5_MoEForCausalLM(config)
model = Ernie4_5_MoeForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch_multi_accelerator
@require_torch_large_accelerator
@require_torch
class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-21B-A3B-PT",
revision="refs/pr/11",
device_map="auto",
load_in_4bit=True,
)

View File

@@ -33,7 +33,7 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
"Ernie4_5Config": ["tie_word_embeddings"],
"Ernie4_5_MoEConfig": ["tie_word_embeddings"],
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
# used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [