[awq] replace scale when we have GELU (#30074)
* fix awq test * style * add log * new fix * style * only modifying impacted model in the end * rename function
This commit is contained in:
@@ -21,6 +21,7 @@ _import_structure = {
|
|||||||
"awq": [
|
"awq": [
|
||||||
"fuse_awq_modules",
|
"fuse_awq_modules",
|
||||||
"post_init_awq_exllama_modules",
|
"post_init_awq_exllama_modules",
|
||||||
|
"replace_quantization_scales",
|
||||||
"replace_with_awq_linear",
|
"replace_with_awq_linear",
|
||||||
],
|
],
|
||||||
"bitsandbytes": [
|
"bitsandbytes": [
|
||||||
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
|
|||||||
from .awq import (
|
from .awq import (
|
||||||
fuse_awq_modules,
|
fuse_awq_modules,
|
||||||
post_init_awq_exllama_modules,
|
post_init_awq_exllama_modules,
|
||||||
|
replace_quantization_scales,
|
||||||
replace_with_awq_linear,
|
replace_with_awq_linear,
|
||||||
)
|
)
|
||||||
from .bitsandbytes import (
|
from .bitsandbytes import (
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
"AWQ (Activation aware Weight Quantization) integration file"
|
"AWQ (Activation aware Weight Quantization) integration file"
|
||||||
from ..activations import ACT2FN
|
from ..activations import ACT2FN
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import PreTrainedModel
|
||||||
from ..utils import is_auto_awq_available, is_torch_available
|
from ..utils import is_auto_awq_available, is_torch_available, logging
|
||||||
from ..utils.quantization_config import (
|
from ..utils.quantization_config import (
|
||||||
AwqBackendPackingMethod,
|
AwqBackendPackingMethod,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
AWQ_FUSED_MAPPINGS = {
|
AWQ_FUSED_MAPPINGS = {
|
||||||
"mistral": {
|
"mistral": {
|
||||||
@@ -56,6 +57,34 @@ AWQ_FUSED_MAPPINGS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Per model type: the attribute name of the MLP activation ("act") and of the
# linear layer that feeds it ("layer_before_act"). `replace_quantization_scales`
# uses these names to locate the activations it has to wrap.
AWQ_SCALES_MAPPINGS = {
    "starcoder2": {"act": "act", "layer_before_act": "c_fc"},
    "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "mpt": {"act": "act", "layer_before_act": "up_proj"},
    "gptj": {"act": "act", "layer_before_act": "fc_in"},
    "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"},
    "bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"},
}


def replace_quantization_scales(model, model_type):
    """
    Wrap the mapped activation modules of `model` in autoawq's `ScaledActivation`.

    The module tree is walked recursively. Whenever a child is named like the
    `"act"` entry of `AWQ_SCALES_MAPPINGS[model_type]` and its parent also owns
    the `"layer_before_act"` attribute, the child is replaced in place by
    `ScaledActivation(child, torch.ones(layer_before_act.out_features))`.
    NOTE(review): the all-ones tensor looks like a placeholder that the
    checkpoint's real scales overwrite at weight-loading time -- confirm.

    Args:
        model: Root module to process; it is modified in place.
        model_type: Key into `AWQ_SCALES_MAPPINGS` (the caller passes
            `model.config.model_type`).

    Returns:
        The same `model` object. It is returned untouched when `model_type`
        has no entry in `AWQ_SCALES_MAPPINGS`.
    """
    mapping = AWQ_SCALES_MAPPINGS.get(model_type)
    if mapping is None:
        # Nothing to wrap for this architecture, so autoawq is not needed either.
        return model

    from awq.modules.act import ScaledActivation

    # Resolved once instead of on every loop iteration / recursive call.
    act_name = mapping["act"]
    layer_before_act_name = mapping["layer_before_act"]

    def _wrap_activations(parent):
        # Snapshot the children: `parent._modules` is reassigned inside the loop.
        for name, child in list(parent.named_children()):
            if name == act_name and hasattr(parent, layer_before_act_name):
                layer_before_act = getattr(parent, layer_before_act_name)
                scale_like = torch.ones(layer_before_act.out_features)
                parent._modules[name] = ScaledActivation(child, scale_like)
            # Keep descending through the original child, wrapped or not.
            _wrap_activations(child)

    _wrap_activations(model)
    return model
|
||||||
|
|
||||||
|
|
||||||
def replace_with_awq_linear(
|
def replace_with_awq_linear(
|
||||||
model,
|
model,
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class AwqQuantizer(HfQuantizer):
|
|||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||||
from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
|
from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
|
||||||
|
|
||||||
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
||||||
|
|
||||||
@@ -86,6 +86,8 @@ class AwqQuantizer(HfQuantizer):
|
|||||||
model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
|
model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model = replace_quantization_scales(model, model.config.model_type)
|
||||||
|
|
||||||
if not has_been_replaced:
|
if not has_been_replaced:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"You are loading an AWQ model but no linear modules were found in your model."
|
"You are loading an AWQ model but no linear modules were found in your model."
|
||||||
|
|||||||
@@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase):
|
|||||||
|
|
||||||
outputs = model.generate(**inputs, max_new_tokens=12)
|
outputs = model.generate(**inputs, max_new_tokens=12)
|
||||||
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
|
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqScaleTest(unittest.TestCase):
    """Checks that loading an AWQ checkpoint wraps the MLP activation in `ScaledActivation`."""

    # Hub id of the AWQ-quantized checkpoint loaded by the test.
    model_name = "TechxGenus/starcoder2-3b-AWQ"

    def test_load_quantized_model(self):
        """
        Simple test that checks if the scales have been replaced in the quantized model
        """
        # Imported lazily: autoawq is only guaranteed when `require_auto_awq` lets the test run.
        from awq.modules.act import ScaledActivation

        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="cuda"
        )
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(quantized_model.model.layers[0].mlp.act, ScaledActivation)
|
||||||
|
|||||||
Reference in New Issue
Block a user