From 4fc708f98c9c8d5cb48e8a2639e3f7a21c65802f Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 5 Mar 2024 03:22:48 +0100 Subject: [PATCH] Exllama kernels support for AWQ models (#28634) * added exllama kernels support for awq models * doc * style * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * refactor * moved exllama post init to after device dispatching * bump autoawq version * added exllama test * style * configurable exllama kernels * copy exllama_config from gptq * moved exllama version check to post init * moved to quantization dockerfile --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- .../Dockerfile | 2 +- src/transformers/integrations/__init__.py | 12 +++- src/transformers/integrations/awq.py | 59 +++++++++++++++++-- src/transformers/quantizers/quantizer_awq.py | 11 ++++ src/transformers/utils/quantization_config.py | 42 +++++++++++-- tests/quantization/autoawq/test_awq.py | 14 +++++ 6 files changed, 127 insertions(+), 13 deletions(-) diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index 66bdcc42ba..1eadb1505b 100644 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -43,7 +43,7 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2 # Add autoawq for quantization testing -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index bded6b3984..200607b0d5 100644 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -18,7 +18,11 @@ from ..utils import _LazyModule _import_structure = { "aqlm": ["replace_with_aqlm_linear"], - "awq": ["fuse_awq_modules", "replace_with_awq_linear"], + "awq": [ + "fuse_awq_modules", + "post_init_awq_exllama_modules", + "replace_with_awq_linear", + ], "bitsandbytes": [ "get_keys_to_not_convert", "replace_8bit_linear", @@ -82,7 +86,11 @@ _import_structure = { if TYPE_CHECKING: from .aqlm import replace_with_aqlm_linear - from .awq import fuse_awq_modules, replace_with_awq_linear + from .awq import ( + fuse_awq_modules, + post_init_awq_exllama_modules, + replace_with_awq_linear, + ) from .bitsandbytes import ( get_keys_to_not_convert, replace_8bit_linear, diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index dd8578ef60..3f9f0d1d21 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -15,7 +15,12 @@ from ..activations import ACT2FN from ..modeling_utils import PreTrainedModel from ..utils import is_auto_awq_available, is_torch_available -from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion +from ..utils.quantization_config import ( + AwqBackendPackingMethod, + AwqConfig, + AWQLinearVersion, + ExllamaVersion, +) if is_torch_available(): @@ -91,13 +96,30 @@ def replace_with_awq_linear( ) if backend == AwqBackendPackingMethod.AUTOAWQ: - from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV - elif backend == AwqBackendPackingMethod.LLMAWQ: + if quantization_config.version == AWQLinearVersion.GEMM: + from awq.modules.linear.gemm import WQLinear_GEMM + + target_cls = WQLinear_GEMM + elif quantization_config.version == AWQLinearVersion.GEMV: + from awq.modules.linear.gemv import WQLinear_GEMV + + target_cls = WQLinear_GEMV + elif quantization_config.version == AWQLinearVersion.EXLLAMA: + if quantization_config.exllama_config["version"] == ExllamaVersion.ONE: + from awq.modules.linear.exllama import WQLinear_Exllama + + target_cls = WQLinear_Exllama + elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO: + from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2 + + target_cls = WQLinear_ExllamaV2 + else: + raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}") + else: + raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}") + else: from awq.quantize.qmodule import WQLinear - if backend == AwqBackendPackingMethod.AUTOAWQ: - target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV - else: target_cls = WQLinear for name, module in model.named_children(): @@ -372,3 +394,28 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na setattr(parent, child_name, fused_attention_layer.to(previous_device)) del q_proj, k_proj, v_proj, o_proj + + +def post_init_awq_exllama_modules(model, exllama_config): + """ + Runs post init for Exllama layers which performs: + - Weights unpacking, reordering and repacking + - Devices scratch space allocation + """ + + if exllama_config["version"] == ExllamaVersion.ONE: + from awq.modules.linear.exllama import exllama_post_init + + model = exllama_post_init(model) + elif exllama_config["version"] == ExllamaVersion.TWO: + from awq.modules.linear.exllamav2 import exllamav2_post_init + + model = exllamav2_post_init( + model, + max_input_len=exllama_config["max_input_len"], + max_batch_size=exllama_config["max_batch_size"], + ) + else: + raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}") + + return model diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py index 08342df175..5e66f9baf1 100644 --- a/src/transformers/quantizers/quantizer_awq.py +++ b/src/transformers/quantizers/quantizer_awq.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging +from ..utils.quantization_config import AWQLinearVersion if is_torch_available(): @@ -98,12 +99,22 @@ class AwqQuantizer(HfQuantizer): model = fuse_awq_modules(model, self.quantization_config) model._awq_is_fused = True # TODO: consider storing this flag in model.config instead + if self.quantization_config.version == AWQLinearVersion.EXLLAMA: + from ..integrations import post_init_awq_exllama_modules + + model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config) + @property def is_serializable(self): # AWQ through auto-awq has been always serializable, except if the model is fused. if self.quantization_config.do_fuse: logger.warning("You cannot save an AWQ model that uses fused modules!") return False + + if self.quantization_config.version == AWQLinearVersion.EXLLAMA: + logger.warning("You cannot save an AWQ model that uses Exllama backend!") + return False + return True @property diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index bcf31ebfab..a29886d8c6 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -44,6 +44,7 @@ class QuantizationMethod(str, Enum): class AWQLinearVersion(str, Enum): GEMM = "gemm" GEMV = "gemv" + EXLLAMA = "exllama" @staticmethod def from_str(version: str): @@ -52,6 +53,8 @@ class AWQLinearVersion(str, Enum): return AWQLinearVersion.GEMM elif version == "gemv": return AWQLinearVersion.GEMV + elif version == "exllama": + return AWQLinearVersion.EXLLAMA else: raise ValueError(f"Unknown AWQLinearVersion {version}") @@ -606,7 +609,7 @@ class AwqConfig(QuantizationConfigMixin): Whether to use zero point quantization. version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`): The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise, - GEMV is better (e.g. < 8 ) + GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels. backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`): The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users that quantize their own models using `llm-awq` library. @@ -620,6 +623,10 @@ class AwqConfig(QuantizationConfigMixin): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models. + exllama_config (`Dict[str, Any]`, *optional*): + You can specify the version of the exllama kernel through the `version` key, the maximum sequence + length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key. + Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset. """ def __init__( @@ -633,6 +640,7 @@ class AwqConfig(QuantizationConfigMixin): fuse_max_seq_len: Optional[int] = None, modules_to_fuse: Optional[dict] = None, modules_to_not_convert: Optional[List] = None, + exllama_config: Optional[Dict[str, int]] = None, **kwargs, ): self.quant_method = QuantizationMethod.AWQ @@ -644,6 +652,7 @@ class AwqConfig(QuantizationConfigMixin): self.backend = backend self.fuse_max_seq_len = fuse_max_seq_len self.modules_to_not_convert = modules_to_not_convert + self.exllama_config = exllama_config self.modules_to_fuse = modules_to_fuse if do_fuse is None: @@ -667,9 +676,9 @@ class AwqConfig(QuantizationConfigMixin): ) self.version = AWQLinearVersion.from_str(self.version) - if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]: + if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]: raise ValueError( - f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}" + f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}" ) if self.backend == AwqBackendPackingMethod.LLMAWQ: @@ -724,9 +733,34 @@ class AwqConfig(QuantizationConfigMixin): f"Required fields are missing in the fusing mapping, required fields are {required_keys}" ) + if self.version == AWQLinearVersion.EXLLAMA: + awq_version_supports_exllama = False + MIN_AWQ_VERSION = "0.2.0" + if is_auto_awq_available(): + awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse( + MIN_AWQ_VERSION + ) + + if not awq_version_supports_exllama: + raise ValueError( + f"You current version of `autoawq` does not support exllama backend, " + f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}." + ) + + if self.exllama_config is None: + self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8} + else: + if "version" not in self.exllama_config: + raise ValueError("`exllama_config` needs to have a `version` key.") + elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]: + exllama_version = self.exllama_config["version"] + raise ValueError( + f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}" + ) + def get_loading_attributes(self): attibutes_dict = copy.deepcopy(self.__dict__) - loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"] + loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index a2dbd904a5..8ed8c394f4 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -192,6 +192,20 @@ class AwqTest(unittest.TestCase): output = quantized_model.generate(**input_ids, max_new_tokens=40) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16) + def test_quantized_model_exllama(self): + """ + Simple test that checks if the quantized model is working properly with exllama backend + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + quantization_config = AwqConfig(version="exllama") + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, quantization_config=quantization_config + ).to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=40) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + def test_quantized_model_no_device_map(self): """ Simple test that checks if the quantized model is working properly