Enables CPU AWQ model with IPEX version. (#33460)

* enable cpu awq ipex linear

* add doc for cpu awq with ipex kernel

* add tests for cpu awq

* fix code style

* fix doc and tests

* Update docs/source/en/quantization/awq.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/autoawq/test_awq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix comments

* fix log

* fix log

* fix style

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jiqing-feng
2024-10-04 22:25:10 +08:00
committed by GitHub
parent de4112e4d2
commit b916efcb3c
6 changed files with 138 additions and 20 deletions

View File

@@ -21,6 +21,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqCon
from transformers.testing_utils import (
require_accelerate,
require_auto_awq,
require_intel_extension_for_pytorch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -490,3 +491,31 @@ class AwqScaleTest(unittest.TestCase):
"TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda"
)
self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))
@slow
@require_auto_awq
@require_accelerate
@require_intel_extension_for_pytorch
class AwqIPEXTest(unittest.TestCase):
def test_quantized_model_ipex(self):
"""
Simple test that checks if the quantized model is working properly with ipex backend
"""
quantization_config = AwqConfig(version="ipex")
model = AutoModelForCausalLM.from_pretrained(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization_config=quantization_config,
device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt")
pad_token_id = tokenizer.eos_token_id
output = model.generate(input_ids, do_sample=False, max_length=20, pad_token_id=pad_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))
expected_output = (
"How to make a cake with a round tin?\nHow to make a cake with a round tin?\n1. Preheat the oven to 180°"
)
self.assertIn(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)