Add Qwen2Moe GGUF loading support (#33264)

* update gguf doc, config and tensor mapping

* add qwen2moe architecture support, GGUFQwen2MoeConverter and q4 unit tests

* apply code style fixes

* reformat files

* assign GGUFQwen2Converter to qwen2_moe
This commit is contained in:
Vladislav Bronzov
2024-09-05 17:42:03 +02:00
committed by GitHub
parent 132e87500e
commit 5d11de4a2f
4 changed files with 76 additions and 5 deletions

View File

@@ -16,7 +16,12 @@ import tempfile
import unittest
from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
from transformers.testing_utils import (
require_gguf,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_torch_available
@@ -33,6 +38,7 @@ class GgufIntegrationTests(unittest.TestCase):
imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
@@ -59,6 +65,7 @@ class GgufIntegrationTests(unittest.TestCase):
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
@@ -298,7 +305,10 @@ class GgufIntegrationTests(unittest.TestCase):
def test_mistral_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
self.mistral_model_id,
gguf_file=self.q4_0_mistral_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -310,7 +320,10 @@ class GgufIntegrationTests(unittest.TestCase):
def test_qwen2_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
self.qwen2_model_id,
gguf_file=self.q4_0_qwen2_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -319,6 +332,21 @@ class GgufIntegrationTests(unittest.TestCase):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_qwen2_moe_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_moe_model_id,
gguf_file=self.q4_0_qwen2_moe_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_llama3_q4_0_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -331,7 +359,10 @@ class GgufIntegrationTests(unittest.TestCase):
def test_llama3_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
self.llama3_model_id,
gguf_file=self.q4_llama3_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)