[Ernie 4.5] Post merge adaptations (#39664)

* ernie 4.5 fixes

* Apply style fixes

* fix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Anton Vlasjuk
2025-07-25 17:36:18 +02:00
committed by GitHub
parent 5d0ba3e479
commit a91653561e
10 changed files with 126 additions and 101 deletions

View File

@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
model = Ernie4_5ForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-0.3B-PT",
revision="refs/pr/3",
device_map="auto",
torch_dtype=torch.bfloat16,
)

View File

@@ -18,7 +18,7 @@ import unittest
import pytest
from transformers import Ernie4_5_MoEConfig, is_torch_available
from transformers import Ernie4_5_MoeConfig, is_torch_available
from transformers.testing_utils import (
cleanup,
is_flaky,
@@ -38,33 +38,33 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
Ernie4_5_MoEForCausalLM,
Ernie4_5_MoEModel,
Ernie4_5_MoeForCausalLM,
Ernie4_5_MoeModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class Ernie4_5_MoEModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoEConfig
class Ernie4_5_MoeModelTester(CausalLMModelTester):
config_class = Ernie4_5_MoeConfig
if is_torch_available():
base_model_class = Ernie4_5_MoEModel
causal_lm_class = Ernie4_5_MoEForCausalLM
base_model_class = Ernie4_5_MoeModel
causal_lm_class = Ernie4_5_MoeForCausalLM
@require_torch
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (
(
Ernie4_5_MoEModel,
Ernie4_5_MoEForCausalLM,
Ernie4_5_MoeModel,
Ernie4_5_MoeForCausalLM,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": Ernie4_5_MoEModel,
"text-generation": Ernie4_5_MoEForCausalLM,
"feature-extraction": Ernie4_5_MoeModel,
"text-generation": Ernie4_5_MoeForCausalLM,
}
if is_torch_available()
else {}
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
test_all_params_have_gradient = False
model_tester_class = Ernie4_5_MoEModelTester
model_tester_class = Ernie4_5_MoeModelTester
@require_flash_attn
@require_torch_gpu
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
if not model_class._supports_flash_attn:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = Ernie4_5_MoEForCausalLM(config)
model = Ernie4_5_MoeForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch_multi_accelerator
@require_torch_large_accelerator
@require_torch
class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
"baidu/ERNIE-4.5-21B-A3B-PT",
revision="refs/pr/11",
device_map="auto",
load_in_4bit=True,
)