[Ernie 4.5] Add ernie text models (#39228)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* init

* copied from remote

* add proper structure and llama like structure

* fixup

* revert to state that works

* get closer to llama

* slow and steady

* some removal

* masks work

* it is indeed the rope implementation, how dafuq does it mesh with the cache now hmm

* nice

* getting closer

* closer to transformers style

* let's simplify this, batching works now

* simplified

* working version with modular

* it is indeed the rotation per weights, make it complete llama style

* cleanup conversion, next to look at -> tokenizer

* remove llama artefacts

* fix modeling tests (common ones)

* style

* integration test + first look into tokenization (will need more work, focussing on modeling other models first)

* style

* working moe version, based on remote

* lets keep it simple and go step by step - transformers annotations for modular and transformers style rope (complex view)

* more cleanup

* refactor namings and remove addition forXXX classes

* our moe won't cut it it seems, correction bias seems to be missing in remote code version

* tokenization change (remote)

* our moe version works when adding normalization :D

* cleanup moe

* nits

* cleanup modeling -> let's get to modular next

* style

* modular v1

* minor things + attempt at conversion (which doesn't work)

* no conversion follow glm, fixup modular and other nits

* modular cleanup

* fixes

* tests, tests, tests + some moe dtype forcing

* simplify modular, fix fatal fa2 bug, remaining tests

* fix import issue?

* some initial docs, fix bnb faulty behavior --> needs to fix some tests because of gate needing to be float

* fix sdpa test, load on init dtype only

* fixup post merge

* style

* fix doc links

* tokenization cleanup beginnings

* simplify tokenizer by a lot as its basically llama

* tokenizer is full llama with different defaults + extra special tokens

* sync og special tokens of ernie

* fix decoding with numbers (also in remote done what a timing), begin of tok tests

* align with remote and preserve special tokens, adjust tests to ernie legacy behavior, warning for questionable behavior (also in llama)

* nits

* docs

* my daily post merge it is

* check

* tokenization update with explanations and conversion script

* review on modular (til), revert some tokenizer things i did prior, remove mtp comment (low prio)

* post merge fixes

* fixup tokenization, llama fast is the way to go

* more fixups

* check

* import fixes

* correction bias following the paddle code

* fix

* fix TP plan, fix correction bias sharding during forward

* style

* whoops

* fix tied weights

* docs and last nit

* license

* flasky tests

* move repo id, update when merged on the hub
This commit is contained in:
Anton Vlasjuk
2025-07-21 19:51:49 +02:00
committed by GitHub
parent 69b158260f
commit b4115a426e
23 changed files with 2956 additions and 2 deletions

View File

@@ -104,9 +104,11 @@ class CausalLMModelTester:
is_decoder=False,
scope=None,
expert_interval=1,
moe_layer_start_index=0,
moe_intermediate_size=12,
shared_expert_intermediate_size=36,
shared_expert_gate=True,
moe_num_shared_experts=2,
num_experts_per_tok=2,
num_experts=8,
mamba_n_groups=1,
@@ -146,9 +148,11 @@ class CausalLMModelTester:
self.head_dim = self.hidden_size // self.num_attention_heads
self.is_decoder = is_decoder
self.expert_interval = expert_interval
self.moe_layer_start_index = moe_layer_start_index
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.shared_expert_gate = shared_expert_gate
self.moe_num_shared_experts = moe_num_shared_experts
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.mamba_n_groups = mamba_n_groups

View File

View File

@@ -0,0 +1,122 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Ernie4.5 model."""
import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
Expectations,
cleanup,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
Ernie4_5Config,
Ernie4_5ForCausalLM,
Ernie4_5Model,
)
from transformers.models.ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding
class Ernie4_5ModelTester(CausalLMModelTester):
if is_torch_available():
config_class = Ernie4_5Config
base_model_class = Ernie4_5Model
causal_lm_class = Ernie4_5ForCausalLM
@require_torch
class Ernie4_5ModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
Ernie4_5Model,
Ernie4_5ForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": Ernie4_5Model,
"text-generation": Ernie4_5ForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
model_tester_class = Ernie4_5ModelTester
rotary_embedding_layer = Ernie4_5RotaryEmbedding # Enables RoPE tests if set
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = Ernie4_5ForCausalLM if is_torch_available() else None
@require_torch_accelerator
class Ernie4_5IntegrationTest(unittest.TestCase):
def setup(self):
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@slow
def test_ernie4_5_0p3B(self):
"""
An integration test for Ernie 4.5 0.3B.
"""
expected_texts = Expectations(
{
("cuda", None): "User: Hey, are you conscious? Can you talk to me?\nAssistant: Hey! I'm here to help you with whatever you need. Are you feeling a bit overwhelmed or stressed? I'm here to listen and provide support.",
}
) # fmt: skip
EXPECTED_TEXT = expected_texts.get_expectation()
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
model = Ernie4_5ForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-0.3B-PT",
revision="refs/pr/3",
device_map="auto",
torch_dtype=torch.bfloat16,
)
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
generated_ids = model.generate(
model_inputs.input_ids,
max_new_tokens=128,
do_sample=False,
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip("\n")
self.assertEqual(generated_text, EXPECTED_TEXT)

View File

View File

@@ -0,0 +1,199 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Ernie4.5 MoE model."""
import tempfile
import unittest
import pytest
from transformers import Ernie4_5_MoEConfig, is_torch_available
from transformers.testing_utils import (
cleanup,
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_large_accelerator,
require_torch_multi_accelerator,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
Ernie4_5_MoEForCausalLM,
Ernie4_5_MoEModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class Ernie4_5_MoEModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoEConfig
if is_torch_available():
base_model_class = Ernie4_5_MoEModel
causal_lm_class = Ernie4_5_MoEForCausalLM
@require_torch
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
Ernie4_5_MoEModel,
Ernie4_5_MoEForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": Ernie4_5_MoEModel,
"text-generation": Ernie4_5_MoEForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
test_all_params_have_gradient = False
model_tester_class = Ernie4_5_MoEModelTester
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
# higher tolerance, not sure where it stems from
assert torch.allclose(logits_fa, logits, atol=1e-2, rtol=1e-2)
# Ignore copy
def test_load_balancing_loss(self):
r"""
Let's make sure we can actually compute the loss and do a backward on it.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.num_experts = 8
config.expert_interval = 2
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = Ernie4_5_MoEForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
self.assertEqual(result.router_logits[0].shape, (91, config.num_experts))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
# First, we make sure that adding padding tokens doesn't change the loss
# loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding)
pad_length = 1000
# Add padding tokens (assume that pad_token_id=1) to input_ids
padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device)
padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left
padded_attention_mask = padded_input_ids.ne(1).to(torch_device)
padded_result = model(padded_input_ids, attention_mask=padded_attention_mask)
torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4)
# We make sure that the loss of including padding tokens != the loss without padding tokens
# if attention_mask=None --> we don't exclude padding tokens
include_padding_result = model(padded_input_ids, attention_mask=None)
# This is to mimic torch.testing.assert_not_close
self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
# Run on runners with larger accelerators (for example A10 instead of T4) with a lot of CPU RAM (e.g. g5-12xlarge)
@require_torch_multi_accelerator
@require_torch_large_accelerator
@require_torch
class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@classmethod
def tearDownClass(cls):
del cls.model
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-21B-A3B-PT",
revision="refs/pr/11",
device_map="auto",
load_in_4bit=True,
)
return cls.model
@require_bitsandbytes
@slow
def test_model_21b_a3b_generation(self):
EXPECTED_TEXT_COMPLETION = "User: Hey, are you conscious? Can you talk to me?\nAssistant: Yes, I am conscious and I can communicate with you. How can I assist you with any questions or information you need?" # fmt: skip
model = self.get_model()
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT", revision="refs/pr/11")
prompt = "Hey, are you conscious? Can you talk to me?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
generated_ids = model.generate(
model_inputs.input_ids,
max_new_tokens=32,
do_sample=False,
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip("\n")
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

View File

@@ -258,10 +258,10 @@ def _test_eager_matches_sdpa_inference(
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa")
except ValueError:
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_eager.eval().to(torch_device)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)