fix moe routing_weights (#39581)
* fix moe routing_weights * fix ernie4_5_moe routing_weights * fix integration test --------- Co-authored-by: llbdyiu66 <llbdyiu66@users.noreply.github.com> Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
@@ -339,12 +339,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
|||||||
# router_logits: (batch * sequence_length, n_experts)
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
router_logits = self.gate(hidden_states.float())
|
router_logits = self.gate(hidden_states.float())
|
||||||
|
|
||||||
# NOTE: we are using the original code base at
|
|
||||||
# https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116
|
|
||||||
# this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`)
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
routing_weights = self.moe_statics(routing_weights)
|
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
|
||||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
|
||||||
routing_weights = routing_weights / torch.clamp(
|
routing_weights = routing_weights / torch.clamp(
|
||||||
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
|
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -150,12 +150,9 @@ class Ernie4_5_MoESparseMoeBlock(nn.Module):
|
|||||||
# router_logits: (batch * sequence_length, n_experts)
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
router_logits = self.gate(hidden_states.float())
|
router_logits = self.gate(hidden_states.float())
|
||||||
|
|
||||||
# NOTE: we are using the original code base at
|
|
||||||
# https://github.com/PaddlePaddle/Paddle/blob/9b40438ce0f6d76b4f08a7837dd1e28b26cf8ee6/python/paddle/incubate/nn/functional/moe_gate_dispatch.py#L109-L116
|
|
||||||
# this might differ from the remote version regarding the bias (see `Ernie4_5_MoEStatics`)
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
routing_weights = self.moe_statics(routing_weights)
|
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
|
||||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
|
||||||
routing_weights = routing_weights / torch.clamp(
|
routing_weights = routing_weights / torch.clamp(
|
||||||
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
|
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
|
|||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@slow
|
@slow
|
||||||
def test_model_21b_a3b_generation(self):
|
def test_model_21b_a3b_generation(self):
|
||||||
EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: Yes, I am conscious and I can communicate with you. How can I assist you with any questions or information you need?" # fmt: skip
|
EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: I don't have consciousness in the way humans do. I'm a text-based AI created to process and generate responses based on patterns in data." # fmt: skip
|
||||||
|
|
||||||
model = self.get_model()
|
model = self.get_model()
|
||||||
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
|
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
|
||||||
|
|||||||
Reference in New Issue
Block a user