From a62f65a989b10bf1130e098bb62f16a5b3994ee8 Mon Sep 17 00:00:00 2001 From: llbdyiu66 <125861386+llbdyiu66@users.noreply.github.com> Date: Wed, 23 Jul 2025 19:20:23 +0800 Subject: [PATCH] fix moe routing_weights (#39581) * fix moe routing_weights * fix ernie4_5_moe routing_weights * fix integration test --------- Co-authored-by: llbdyiu66 Co-authored-by: Vasqu Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> --- .../models/ernie4_5_moe/modeling_ernie4_5_moe.py | 7 ++----- .../models/ernie4_5_moe/modular_ernie4_5_moe.py | 7 ++----- tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py | 2 +- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 74671bb33f..14e598bff9 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -339,12 +339,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module): # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states.float()) - # NOTE: we are using the original code base at - # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116 - # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights = self.moe_statics(routing_weights) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) routing_weights = routing_weights / torch.clamp( routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index 0763e415c5..3c4e068d37 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -150,12 +150,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module): # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states.float()) - # NOTE: we are using the original code base at - # https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116 - # this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights = self.moe_statics(routing_weights) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) + routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) routing_weights = routing_weights / torch.clamp( routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) diff --git a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py index 63fb00745c..b8a8130155 100644 --- a/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py +++ b/tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py @@ -181,7 +181,7 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase): @require_bitsandbytes @slow def test_model_21b_a3b_generation(self): - EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: Yes, I am conscious and I can communicate with you. How can I assist you with any questions or information you need?" # fmt: skip + EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: I don't have consciousness in the way humans do. I'm a text-based AI created to process and generate responses based on patterns in data." # fmt: skip model = self.get_model() tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")