[awq] replace scale when we have GELU (#30074)
* fix awq test * style * add log * new fix * style * only modifying impacted model in the end * rename function
This commit is contained in:
@@ -21,6 +21,7 @@ _import_structure = {
|
||||
"awq": [
|
||||
"fuse_awq_modules",
|
||||
"post_init_awq_exllama_modules",
|
||||
"replace_quantization_scales",
|
||||
"replace_with_awq_linear",
|
||||
],
|
||||
"bitsandbytes": [
|
||||
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
|
||||
from .awq import (
|
||||
fuse_awq_modules,
|
||||
post_init_awq_exllama_modules,
|
||||
replace_quantization_scales,
|
||||
replace_with_awq_linear,
|
||||
)
|
||||
from .bitsandbytes import (
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
"AWQ (Activation aware Weight Quantization) integration file"
|
||||
from ..activations import ACT2FN
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..utils import is_auto_awq_available, is_torch_available
|
||||
from ..utils import is_auto_awq_available, is_torch_available, logging
|
||||
from ..utils.quantization_config import (
|
||||
AwqBackendPackingMethod,
|
||||
AwqConfig,
|
||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
AWQ_FUSED_MAPPINGS = {
|
||||
"mistral": {
|
||||
@@ -56,6 +57,34 @@ AWQ_FUSED_MAPPINGS = {
|
||||
},
|
||||
}
|
||||
|
||||
AWQ_SCALES_MAPPINGS = {
|
||||
"starcoder2": {"act": "act", "layer_before_act": "c_fc"},
|
||||
"RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"},
|
||||
"falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"},
|
||||
"mpt": {"act": "act", "layer_before_act": "up_proj"},
|
||||
"gptj": {"act": "act", "layer_before_act": "fc_in"},
|
||||
"gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"},
|
||||
"gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"},
|
||||
"bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"},
|
||||
}
|
||||
|
||||
|
||||
def replace_quantization_scales(model, model_type):
|
||||
from awq.modules.act import ScaledActivation
|
||||
|
||||
if model_type not in AWQ_SCALES_MAPPINGS:
|
||||
return model
|
||||
for name, module in model.named_children():
|
||||
act_name = AWQ_SCALES_MAPPINGS[model_type]["act"]
|
||||
layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]
|
||||
if name == act_name and hasattr(model, layer_before_act_name):
|
||||
layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"])
|
||||
size = layer_before_act.out_features
|
||||
scale_like = torch.ones(size)
|
||||
model._modules[name] = ScaledActivation(module, scale_like)
|
||||
_ = replace_quantization_scales(module, model_type)
|
||||
return model
|
||||
|
||||
|
||||
def replace_with_awq_linear(
|
||||
model,
|
||||
|
||||
@@ -75,7 +75,7 @@ class AwqQuantizer(HfQuantizer):
|
||||
return torch_dtype
|
||||
|
||||
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
|
||||
from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
|
||||
|
||||
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
||||
|
||||
@@ -86,6 +86,8 @@ class AwqQuantizer(HfQuantizer):
|
||||
model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
|
||||
)
|
||||
|
||||
model = replace_quantization_scales(model, model.config.model_type)
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading an AWQ model but no linear modules were found in your model."
|
||||
|
||||
@@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase):
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=12)
|
||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_auto_awq
|
||||
@require_accelerate
|
||||
class AwqScaleTest(unittest.TestCase):
|
||||
model_name = "TechxGenus/starcoder2-3b-AWQ"
|
||||
|
||||
def test_load_quantized_model(self):
|
||||
from awq.modules.act import ScaledActivation
|
||||
|
||||
"""
|
||||
Simple test that checks if the scales have been replaced in the quantized model
|
||||
"""
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))
|
||||
|
||||
Reference in New Issue
Block a user