From de6e0db184d565847356a6a08dde2f043e744c72 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 13 May 2024 11:41:03 +0200 Subject: [PATCH] [awq] replace scale when we have GELU (#30074) * fix awq test * style * add log * new fix * style * only modifying impacted model in the end * rename function --- src/transformers/integrations/__init__.py | 2 ++ src/transformers/integrations/awq.py | 31 +++++++++++++++++++- src/transformers/quantizers/quantizer_awq.py | 4 ++- tests/quantization/autoawq/test_awq.py | 19 ++++++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 69fb0e3259..19a3f421ca 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -21,6 +21,7 @@ _import_structure = { "awq": [ "fuse_awq_modules", "post_init_awq_exllama_modules", + "replace_quantization_scales", "replace_with_awq_linear", ], "bitsandbytes": [ @@ -92,6 +93,7 @@ if TYPE_CHECKING: from .awq import ( fuse_awq_modules, post_init_awq_exllama_modules, + replace_quantization_scales, replace_with_awq_linear, ) from .bitsandbytes import ( diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index a543860f10..a83b27e95a 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -14,7 +14,7 @@ "AWQ (Activation aware Weight Quantization) integration file" from ..activations import ACT2FN from ..modeling_utils import PreTrainedModel -from ..utils import is_auto_awq_available, is_torch_available +from ..utils import is_auto_awq_available, is_torch_available, logging from ..utils.quantization_config import ( AwqBackendPackingMethod, AwqConfig, @@ -27,6 +27,7 @@ if is_torch_available(): import torch import torch.nn as nn +logger = logging.get_logger(__name__) AWQ_FUSED_MAPPINGS = { "mistral": { @@ -56,6 +57,34 @@ AWQ_FUSED_MAPPINGS = { }, } +AWQ_SCALES_MAPPINGS = { + "starcoder2": {"act": "act", "layer_before_act": "c_fc"}, + "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "mpt": {"act": "act", "layer_before_act": "up_proj"}, + "gptj": {"act": "act", "layer_before_act": "fc_in"}, + "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"}, + "bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"}, +} + + +def replace_quantization_scales(model, model_type): + from awq.modules.act import ScaledActivation + + if model_type not in AWQ_SCALES_MAPPINGS: + return model + for name, module in model.named_children(): + act_name = AWQ_SCALES_MAPPINGS[model_type]["act"] + layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"] + if name == act_name and hasattr(model, layer_before_act_name): + layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]) + size = layer_before_act.out_features + scale_like = torch.ones(size) + model._modules[name] = ScaledActivation(module, scale_like) + _ = replace_quantization_scales(module, model_type) + return model + def replace_with_awq_linear( model, diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py index 5e66f9baf1..f9e4444f07 100644 --- a/src/transformers/quantizers/quantizer_awq.py +++ b/src/transformers/quantizers/quantizer_awq.py @@ -75,7 +75,7 @@ class AwqQuantizer(HfQuantizer): return torch_dtype def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): - from ..integrations import get_keys_to_not_convert, replace_with_awq_linear + from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear self.modules_to_not_convert = get_keys_to_not_convert(model) @@ -86,6 +86,8 @@ class AwqQuantizer(HfQuantizer): model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert ) + model = replace_quantization_scales(model, model.config.model_type) + if not has_been_replaced: logger.warning( "You are loading an AWQ model but no linear modules were found in your model." diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index e2369f07b2..20ecd783cf 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase): outputs = model.generate(**inputs, max_new_tokens=12) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL) + + +@slow +@require_torch_gpu +@require_auto_awq +@require_accelerate +class AwqScaleTest(unittest.TestCase): + model_name = "TechxGenus/starcoder2-3b-AWQ" + + def test_load_quantized_model(self): + from awq.modules.act import ScaledActivation + + """ + Simple test that checks if the scales have been replaced in the quantized model + """ + quantized_model = AutoModelForCausalLM.from_pretrained( + "TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda" + ) + self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))