Enables CPU AWQ model with IPEX version. (#33460)

* enable cpu awq ipex linear

* add doc for cpu awq with ipex kernel

* add tests for cpu awq

* fix code style

* fix doc and tests

* Update docs/source/en/quantization/awq.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/autoawq/test_awq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fix comments

* fix log

* fix log

* fix style

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
jiqing-feng
2024-10-04 22:25:10 +08:00
committed by GitHub
parent de4112e4d2
commit b916efcb3c
6 changed files with 138 additions and 20 deletions

View File

@@ -230,3 +230,44 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
Note this feature is supported on AMD GPUs.
</Tip>
## CPU support
Recent versions of `autoawq` supports CPU with ipex op optimizations. To get started, first install the latest version of `autoawq` by running:
```bash
pip install intel-extension-for-pytorch
pip install git+https://github.com/casper-hansen/AutoAWQ.git
```
Get started by passing an `AwqConfig()` with `version="ipex"`.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig
quantization_config = AwqConfig(version="ipex")
model = AutoModelForCausalLM.from_pretrained(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization_config=quantization_config,
device_map="cpu",
)
input_ids = torch.randint(0, 100, (1, 128), dtype=torch.long, device="cpu")
output = model(input_ids)
print(output.logits)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt")
pad_token_id = tokenizer.eos_token_id
output = model.generate(input_ids, do_sample=True, max_length=50, pad_token_id=pad_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
<Tip warning={true}>
Note this feature is supported on Intel CPUs.
</Tip>

View File

@@ -21,6 +21,7 @@ _import_structure = {
"awq": [
"fuse_awq_modules",
"post_init_awq_exllama_modules",
"post_init_awq_ipex_modules",
"replace_quantization_scales",
"replace_with_awq_linear",
],
@@ -115,6 +116,7 @@ if TYPE_CHECKING:
from .awq import (
fuse_awq_modules,
post_init_awq_exllama_modules,
post_init_awq_ipex_modules,
replace_quantization_scales,
replace_with_awq_linear,
)

View File

@@ -145,6 +145,10 @@ def replace_with_awq_linear(
target_cls = WQLinear_ExllamaV2
else:
raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
elif quantization_config.version == AWQLinearVersion.IPEX:
from awq.modules.linear.gemm_ipex import WQLinear_IPEX
target_cls = WQLinear_IPEX
else:
raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
else:
@@ -266,8 +270,11 @@ def fuse_awq_modules(model, quantization_config):
# Replace layer norms
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)
# Replace MLP layers
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
# Replace MLP layers if awq version is not ipex.
if quantization_config.version != "ipex":
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
else:
logger.info("The IPEX version AWQ does not support fuse mlp for now.")
# Replace attention layers
attention_has_been_fused = _fuse_awq_attention_layers(
@@ -372,7 +379,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
The `QuantAttentionFused` class as it only supports that class
for now.
"""
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_IPEX
module_has_been_fused = False
@@ -389,6 +396,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
elif isinstance(q_proj, WQLinear_GEMM):
linear_target_cls = WQLinear_GEMM
cat_dim = 1
elif isinstance(q_proj, WQLinear_IPEX):
linear_target_cls = WQLinear_IPEX
cat_dim = 1
else:
raise ValueError("Unsupported q_proj type: {type(q_proj)}")
@@ -466,3 +476,16 @@ def post_init_awq_exllama_modules(model, exllama_config):
raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
return model
def post_init_awq_ipex_modules(model):
"""
Runs post init for IPEX layers which performs:
- Weights packing, reordering and repacking
"""
from awq.modules.linear.gemm_ipex import ipex_post_init
model = ipex_post_init(model)
return model

View File

@@ -46,26 +46,39 @@ class AwqQuantizer(HfQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run AWQ quantized model.")
if not is_auto_awq_available():
raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")
if not is_accelerate_available():
raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
if device_map is None:
logger.warning_once(
"You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
elif device_map is not None:
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
if self.quantization_config.version == AWQLinearVersion.IPEX:
if (
device_map is not None
and isinstance(device_map, dict)
and (torch.device("cpu") not in device_map.values() or len(device_map.values()) > 1)
):
raise ValueError(
"You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
"You are attempting to load an IPEX version AWQ model with a device_map that contains more than CPU."
" This is not supported. Please make sure only cpu in the device_map."
)
else:
if not torch.cuda.is_available():
raise RuntimeError(
"GPU is required to run AWQ quantized model. You can use IPEX version AWQ if you have an Intel CPU"
)
if device_map is None:
logger.warning_once(
"You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
elif device_map is not None:
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype):
if torch_dtype is None:
@@ -106,6 +119,11 @@ class AwqQuantizer(HfQuantizer):
model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
if self.quantization_config.version == AWQLinearVersion.IPEX:
from ..integrations import post_init_awq_ipex_modules
model = post_init_awq_ipex_modules(model)
def is_serializable(self, safe_serialization=None):
# AWQ through auto-awq has been always serializable, except if the model is fused.
if self.quantization_config.do_fuse:

View File

@@ -51,6 +51,7 @@ class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
EXLLAMA = "exllama"
IPEX = "ipex"
@staticmethod
def from_str(version: str):
@@ -61,6 +62,8 @@ class AWQLinearVersion(str, Enum):
return AWQLinearVersion.GEMV
elif version == "exllama":
return AWQLinearVersion.EXLLAMA
elif version == "ipex":
return AWQLinearVersion.IPEX
else:
raise ValueError(f"Unknown AWQLinearVersion {version}")
@@ -830,18 +833,20 @@ class AwqConfig(QuantizationConfigMixin):
r"""
Safety checker that arguments are correct
"""
if not torch.cuda.is_available():
raise ValueError("AWQ is only available on GPU")
if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
raise ValueError(
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
)
self.version = AWQLinearVersion.from_str(self.version)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
if self.version not in [
AWQLinearVersion.GEMM,
AWQLinearVersion.GEMV,
AWQLinearVersion.EXLLAMA,
AWQLinearVersion.IPEX,
]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA, AWQLinearVersion.IPEX] - not recognized version {self.version}"
)
if self.backend == AwqBackendPackingMethod.LLMAWQ:

View File

@@ -21,6 +21,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqCon
from transformers.testing_utils import (
require_accelerate,
require_auto_awq,
require_intel_extension_for_pytorch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -490,3 +491,31 @@ class AwqScaleTest(unittest.TestCase):
"TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda"
)
self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))
@slow
@require_auto_awq
@require_accelerate
@require_intel_extension_for_pytorch
class AwqIPEXTest(unittest.TestCase):
def test_quantized_model_ipex(self):
"""
Simple test that checks if the quantized model is working properly with ipex backend
"""
quantization_config = AwqConfig(version="ipex")
model = AutoModelForCausalLM.from_pretrained(
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization_config=quantization_config,
device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt")
pad_token_id = tokenizer.eos_token_id
output = model.generate(input_ids, do_sample=False, max_length=20, pad_token_id=pad_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))
expected_output = (
"How to make a cake with a round tin?\nHow to make a cake with a round tin?\n1. Preheat the oven to 180°"
)
self.assertIn(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)