[Ernie 4.5] Post merge adaptations (#39664)
* ernie 4.5 fixes * Apply style fixes * fix --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -102,7 +102,6 @@ class Ernie4_5IntegrationTest(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", revision="refs/pr/3")
|
||||
model = Ernie4_5ForCausalLM.from_pretrained(
|
||||
"baidu/ERNIE-4.5-0.3B-PT",
|
||||
revision="refs/pr/3",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import Ernie4_5_MoEConfig, is_torch_available
|
||||
from transformers import Ernie4_5_MoeConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
is_flaky,
|
||||
@@ -38,33 +38,33 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Ernie4_5_MoEForCausalLM,
|
||||
Ernie4_5_MoEModel,
|
||||
Ernie4_5_MoeForCausalLM,
|
||||
Ernie4_5_MoeModel,
|
||||
)
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
class Ernie4_5_MoEModelTester(CausalLMModelTester):
|
||||
config_class = Ernie4_5_MoEConfig
|
||||
class Ernie4_5_MoeModelTester(CausalLMModelTester):
|
||||
config_class = Ernie4_5_MoeConfig
|
||||
if is_torch_available():
|
||||
base_model_class = Ernie4_5_MoEModel
|
||||
causal_lm_class = Ernie4_5_MoEForCausalLM
|
||||
base_model_class = Ernie4_5_MoeModel
|
||||
causal_lm_class = Ernie4_5_MoeForCausalLM
|
||||
|
||||
|
||||
@require_torch
|
||||
class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Ernie4_5_MoEModel,
|
||||
Ernie4_5_MoEForCausalLM,
|
||||
Ernie4_5_MoeModel,
|
||||
Ernie4_5_MoeForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Ernie4_5_MoEModel,
|
||||
"text-generation": Ernie4_5_MoEForCausalLM,
|
||||
"feature-extraction": Ernie4_5_MoeModel,
|
||||
"text-generation": Ernie4_5_MoeForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@@ -73,7 +73,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
test_all_params_have_gradient = False
|
||||
model_tester_class = Ernie4_5_MoEModelTester
|
||||
model_tester_class = Ernie4_5_MoeModelTester
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -82,7 +82,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -123,7 +123,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
config.output_router_logits = True
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
model = Ernie4_5_MoEForCausalLM(config)
|
||||
model = Ernie4_5_MoeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask)
|
||||
@@ -153,7 +153,7 @@ class Ernie4_5_MoEModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
@require_torch_multi_accelerator
|
||||
@require_torch_large_accelerator
|
||||
@require_torch
|
||||
class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
|
||||
class Ernie4_5_MoeIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = None
|
||||
@@ -169,9 +169,8 @@ class Ernie4_5_MoEIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def get_model(cls):
|
||||
if cls.model is None:
|
||||
cls.model = Ernie4_5_MoEForCausalLM.from_pretrained(
|
||||
cls.model = Ernie4_5_MoeForCausalLM.from_pretrained(
|
||||
"baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
revision="refs/pr/11",
|
||||
device_map="auto",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user