Hqq serialization (#33141)
* HQQ model serialization attempt * fix hqq dispatch and unexpected keys * style * remove check_old_param * revert to check HQQLinear in quantizer_hqq.py * revert to check HQQLinear in quantizer_hqq.py * update HqqConfig default params * make ci happy * make ci happy * revert to HQQLinear check in quantizer_hqq.py * check hqq_min version 0.2.0 * set axis=1 as default in quantization_config.py * validate_env with hqq>=0.2.0 version message * deprecated hqq kwargs message * make ci happy * remove run_expected_keys_check hack + bump to 0.2.1 min hqq version * fix unexpected_keys hqq update * add pre_quantized check * add update_expected_keys to base quantizerr * ci base.py fix? * ci base.py fix? * fix "quantization typo" src/transformers/utils/quantization_config.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix post merge --------- Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
6
docs/source/en/quantization/hqq.md
Normal file → Executable file
6
docs/source/en/quantization/hqq.md
Normal file → Executable file
@@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
||||||
|
|
||||||
# Method 1: all linear layers will use the same quantization config
|
# Method 1: all linear layers will use the same quantization config
|
||||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
|
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||||
```
|
```
|
||||||
|
|
||||||
``` Python
|
``` Python
|
||||||
# Method 2: each linear layer with the same tag will use a dedicated quantization config
|
# Method 2: each linear layer with the same tag will use a dedicated quantization config
|
||||||
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
|
q4_config = {'nbits':4, 'group_size':64}
|
||||||
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
|
q3_config = {'nbits':3, 'group_size':32}
|
||||||
quant_config = HqqConfig(dynamic_config={
|
quant_config = HqqConfig(dynamic_config={
|
||||||
'self_attn.q_proj':q4_config,
|
'self_attn.q_proj':q4_config,
|
||||||
'self_attn.k_proj':q4_config,
|
'self_attn.k_proj':q4_config,
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_
|
|||||||
|
|
||||||
has_been_replaced = True
|
has_been_replaced = True
|
||||||
|
|
||||||
|
# Add these placeholder attributes so that loading does not fail
|
||||||
|
for att in ["W_q", "meta"]:
|
||||||
|
setattr(module, att, None)
|
||||||
|
|
||||||
if len(list(module.children())) > 0:
|
if len(list(module.children())) > 0:
|
||||||
_, has_been_replaced = _prepare_for_hqq_linear(
|
_, has_been_replaced = _prepare_for_hqq_linear(
|
||||||
module,
|
module,
|
||||||
@@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
|||||||
|
|
||||||
# Convert quantization_config to layer-wise config
|
# Convert quantization_config to layer-wise config
|
||||||
skip_modules = quantization_config.skip_modules
|
skip_modules = quantization_config.skip_modules
|
||||||
quant_config = quantization_config.to_dict()
|
quant_config = quantization_config.quant_config
|
||||||
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
|
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
|
||||||
|
|
||||||
if any(key in linear_tags for key in quant_config.keys()):
|
if any(key in linear_tags for key in quant_config.keys()):
|
||||||
@@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
|||||||
)
|
)
|
||||||
|
|
||||||
# We store quantization config as linear_tag -> hqq quant config
|
# We store quantization config as linear_tag -> hqq quant config
|
||||||
model.config.quantization_config = patch_params
|
model.config.quantization_config = {
|
||||||
|
"quant_config": quant_config,
|
||||||
|
"quant_method": quantization_config.quant_method,
|
||||||
|
"skip_modules": skip_modules,
|
||||||
|
}
|
||||||
|
|
||||||
if not has_been_replaced:
|
if not has_been_replaced:
|
||||||
logger.warning("No linear modules were found in your model for quantization.")
|
logger.warning("No linear modules were found in your model for quantization.")
|
||||||
|
|||||||
@@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
|
|||||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||||
|
|
||||||
old_param = model
|
old_param = model
|
||||||
splits = param_name.split(".")
|
splits = param_name.split(".")
|
||||||
for split in splits:
|
for split in splits:
|
||||||
old_param = getattr(old_param, split)
|
old_param = getattr(old_param, split)
|
||||||
|
# Not all the attributes of a module are Parameters/Tensor
|
||||||
|
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||||
|
old_param = None
|
||||||
if old_param is None:
|
if old_param is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
if old_param is not None:
|
if old_param is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
param = param.to(old_param.dtype)
|
param = param.to(old_param.dtype)
|
||||||
@@ -3819,6 +3824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
from_pt = not (from_tf | from_flax)
|
from_pt = not (from_tf | from_flax)
|
||||||
|
|
||||||
# load pt weights early so that we know which dtype to init the model under
|
# load pt weights early so that we know which dtype to init the model under
|
||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
if not is_sharded and state_dict is None:
|
if not is_sharded and state_dict is None:
|
||||||
# Time to load the checkpoint
|
# Time to load the checkpoint
|
||||||
@@ -4176,6 +4182,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
|
if hf_quantizer is not None:
|
||||||
|
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
||||||
|
|
||||||
def _fix_key(key):
|
def _fix_key(key):
|
||||||
if "beta" in key:
|
if "beta" in key:
|
||||||
return key.replace("beta", "bias")
|
return key.replace("beta", "bias")
|
||||||
@@ -4290,7 +4299,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
value = torch.empty(*param.size(), dtype=target_dtype)
|
value = torch.empty(*param.size(), dtype=target_dtype)
|
||||||
if (
|
if (
|
||||||
not is_quantized
|
not is_quantized
|
||||||
or getattr(hf_quantizer, "requires_parameters_quantization", False)
|
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
|
||||||
or not hf_quantizer.check_quantized_param(
|
or not hf_quantizer.check_quantized_param(
|
||||||
model, param_value=value, param_name=key, state_dict={}
|
model, param_value=value, param_name=key, state_dict={}
|
||||||
)
|
)
|
||||||
|
|||||||
12
src/transformers/quantizers/base.py
Normal file → Executable file
12
src/transformers/quantizers/base.py
Normal file → Executable file
@@ -109,6 +109,18 @@ class HfQuantizer(ABC):
|
|||||||
"""
|
"""
|
||||||
return missing_keys
|
return missing_keys
|
||||||
|
|
||||||
|
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Override this method if you want to adjust the list of expected keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_keys (`List[str]`, *optional*):
|
||||||
|
The list of the expected keys in the initialized model.
|
||||||
|
loaded_keys (`List[str]`, *optional*):
|
||||||
|
The list of the loaded keys in the checkpoint.
|
||||||
|
"""
|
||||||
|
return expected_keys
|
||||||
|
|
||||||
def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
|
def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
|
||||||
"""
|
"""
|
||||||
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
|
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
def validate_environment(self, *args, **kwargs):
|
def validate_environment(self, *args, **kwargs):
|
||||||
if not (is_hqq_available()):
|
if not (is_hqq_available()):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
|
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||||
@@ -91,6 +91,65 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
else:
|
else:
|
||||||
self.using_multi_gpu = len(set(device_map.values())) > 1
|
self.using_multi_gpu = len(set(device_map.values())) > 1
|
||||||
|
|
||||||
|
def update_missing_keys(
|
||||||
|
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
|
||||||
|
) -> List[str]:
|
||||||
|
if self.pre_quantized:
|
||||||
|
return [key for key in missing_keys if ("weight" not in key)]
|
||||||
|
else:
|
||||||
|
return missing_keys
|
||||||
|
|
||||||
|
# Adds missing keys for HQQLinear modules that are loaded but the model was initialized with torch.nn.Linear
|
||||||
|
def update_expected_keys(
|
||||||
|
self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
|
||||||
|
) -> List[str]:
|
||||||
|
if not self.pre_quantized:
|
||||||
|
return expected_keys
|
||||||
|
|
||||||
|
# Collects all quantizable (linear) layers
|
||||||
|
def _find_hqq_quantizable_layers(model, layers):
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if isinstance(module, (torch.nn.Linear)):
|
||||||
|
layers.add(module.name)
|
||||||
|
_find_hqq_quantizable_layers(module, layers)
|
||||||
|
|
||||||
|
new_keys = set(expected_keys)
|
||||||
|
if is_hqq_available():
|
||||||
|
from hqq.core.quantize import HQQLinear
|
||||||
|
|
||||||
|
# Name modules
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
module.name = name
|
||||||
|
|
||||||
|
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
|
||||||
|
_valid_modules = set()
|
||||||
|
_find_hqq_quantizable_layers(model, _valid_modules)
|
||||||
|
_valid_modules -= set(model.config.quantization_config["skip_modules"])
|
||||||
|
|
||||||
|
# Append new expected layers based on _ref_keys
|
||||||
|
_ref_keys = HQQLinear(
|
||||||
|
linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
|
||||||
|
).state_dict_keys() - {"bias"}
|
||||||
|
|
||||||
|
# Clean-up
|
||||||
|
_rm_keys = set()
|
||||||
|
for key in new_keys:
|
||||||
|
if any(_module in key for _module in _valid_modules):
|
||||||
|
_rm_keys.add(key)
|
||||||
|
new_keys -= _rm_keys
|
||||||
|
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
|
||||||
|
|
||||||
|
# Re-populate Linear/HQQLinear
|
||||||
|
for _module in _valid_modules:
|
||||||
|
if _module + ".weight" in loaded_keys:
|
||||||
|
new_keys.add(_module + ".weight")
|
||||||
|
else:
|
||||||
|
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
|
||||||
|
if _module + ".bias" in loaded_keys:
|
||||||
|
new_keys.add(_module + ".bias")
|
||||||
|
|
||||||
|
return list(new_keys)
|
||||||
|
|
||||||
def check_quantized_param(
|
def check_quantized_param(
|
||||||
self,
|
self,
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
@@ -99,9 +158,18 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
state_dict: Dict[str, Any],
|
state_dict: Dict[str, Any],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
if is_hqq_available():
|
||||||
|
from hqq.core.quantize import HQQLinear
|
||||||
module, tensor_name = get_module_from_name(model, param_name)
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
|
||||||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
if self.pre_quantized:
|
||||||
|
return (
|
||||||
|
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
|
||||||
|
and tensor_name != "weight"
|
||||||
|
and tensor_name != "bias"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
|
||||||
|
|
||||||
def create_quantized_param(
|
def create_quantized_param(
|
||||||
self,
|
self,
|
||||||
@@ -122,13 +190,43 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
from hqq.core.quantize import HQQLinear
|
from hqq.core.quantize import HQQLinear
|
||||||
|
|
||||||
module, tensor_name = get_module_from_name(model, param_name)
|
module, tensor_name = get_module_from_name(model, param_name)
|
||||||
|
layer_name = ".".join(param_name.split(".")[:-1])
|
||||||
layer_name = param_name.replace(".weight", "").replace(".bias", "")
|
|
||||||
parent_module = find_parent(model, layer_name)
|
parent_module = find_parent(model, layer_name)
|
||||||
node = layer_name.split(".")[-1]
|
node = layer_name.split(".")[-1]
|
||||||
|
|
||||||
# Step 0: set module state_dict
|
# set module state_dict
|
||||||
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
|
module_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if layer_name + "." in k:
|
||||||
|
module_state_dict[k.split(".")[-1]] = v
|
||||||
|
if unexpected_keys is not None and k in unexpected_keys:
|
||||||
|
unexpected_keys.remove(k)
|
||||||
|
|
||||||
|
if self.pre_quantized:
|
||||||
|
if isinstance(module, HQQLinear):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
hqq_layer = HQQLinear(
|
||||||
|
linear_layer=None,
|
||||||
|
quant_config=None,
|
||||||
|
compute_dtype=self.torch_dtype,
|
||||||
|
device=target_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
hqq_layer.load_state_dict(module_state_dict)
|
||||||
|
|
||||||
|
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
|
||||||
|
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
|
||||||
|
|
||||||
|
if self.using_multi_gpu:
|
||||||
|
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
|
||||||
|
|
||||||
|
setattr(parent_module, node, hqq_layer)
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
del module.__dict__, module
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return
|
||||||
|
|
||||||
# Step 1: populate module with weight/bias from module state dict
|
# Step 1: populate module with weight/bias from module state dict
|
||||||
for key in module_state_dict:
|
for key in module_state_dict:
|
||||||
@@ -136,7 +234,6 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
|
|
||||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing it on the module
|
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing it on the module
|
||||||
# directly doesn't work.
|
# directly doesn't work.
|
||||||
|
|
||||||
if hasattr(module, "quant_config"):
|
if hasattr(module, "quant_config"):
|
||||||
hqq_layer = HQQLinear(
|
hqq_layer = HQQLinear(
|
||||||
module,
|
module,
|
||||||
@@ -192,7 +289,7 @@ class HqqHfQuantizer(HfQuantizer):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def is_serializable(self, safe_serialization=None):
|
def is_serializable(self, safe_serialization=None):
|
||||||
return False
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_trainable(self) -> bool:
|
def is_trainable(self) -> bool:
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ ACCELERATE_MIN_VERSION = "0.26.0"
|
|||||||
FSDP_MIN_VERSION = "1.12.0"
|
FSDP_MIN_VERSION = "1.12.0"
|
||||||
GGUF_MIN_VERSION = "0.10.0"
|
GGUF_MIN_VERSION = "0.10.0"
|
||||||
XLA_FSDPV2_MIN_VERSION = "2.2.0"
|
XLA_FSDPV2_MIN_VERSION = "2.2.0"
|
||||||
|
HQQ_MIN_VERSION = "0.2.1"
|
||||||
|
|
||||||
|
|
||||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||||
@@ -181,7 +182,7 @@ _torchao_available = _is_package_available("torchao")
|
|||||||
_torchdistx_available = _is_package_available("torchdistx")
|
_torchdistx_available = _is_package_available("torchdistx")
|
||||||
_torchvision_available = _is_package_available("torchvision")
|
_torchvision_available = _is_package_available("torchvision")
|
||||||
_mlx_available = _is_package_available("mlx")
|
_mlx_available = _is_package_available("mlx")
|
||||||
_hqq_available = _is_package_available("hqq")
|
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
|
||||||
_tiktoken_available = _is_package_available("tiktoken")
|
_tiktoken_available = _is_package_available("tiktoken")
|
||||||
_blobfile_available = _is_package_available("blobfile")
|
_blobfile_available = _is_package_available("blobfile")
|
||||||
_liger_kernel_available = _is_package_available("liger_kernel")
|
_liger_kernel_available = _is_package_available("liger_kernel")
|
||||||
@@ -323,8 +324,8 @@ def is_torch_deterministic():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_hqq_available():
|
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
|
||||||
return _hqq_available
|
return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
|
||||||
|
|
||||||
|
|
||||||
def is_pygments_available():
|
def is_pygments_available():
|
||||||
|
|||||||
@@ -193,15 +193,9 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
Number of bits. Supported values are (8, 4, 3, 2, 1).
|
Number of bits. Supported values are (8, 4, 3, 2, 1).
|
||||||
group_size (`int`, *optional*, defaults to 64):
|
group_size (`int`, *optional*, defaults to 64):
|
||||||
Group-size value. Supported values are any value that is divisible by weight.shape[axis].
|
Group-size value. Supported values are any value that is divisible by weight.shape[axis].
|
||||||
quant_zero (`bool`, *optional*, defaults to `True`):
|
|
||||||
Quantize the zero-point if set to `True`.
|
|
||||||
quant_scale (`bool`, *optional*, defaults to `False`):
|
|
||||||
Quantize the scaling if set to `True`.
|
|
||||||
offload_meta (`bool`, *optional*, defaults to `False`):
|
|
||||||
Offload the meta-data to the CPU if set to `True`.
|
|
||||||
view_as_float (`bool`, *optional*, defaults to `False`):
|
view_as_float (`bool`, *optional*, defaults to `False`):
|
||||||
View the quantized weight as float (used in distributed training) if set to `True`.
|
View the quantized weight as float (used in distributed training) if set to `True`.
|
||||||
axis (`int`, *optional*, defaults to 0):
|
axis (`Optional[int]`, *optional*):
|
||||||
Axis along which grouping is performed. Supported values are 0 or 1.
|
Axis along which grouping is performed. Supported values are 0 or 1.
|
||||||
dynamic_config (dict, *optional*):
|
dynamic_config (dict, *optional*):
|
||||||
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
|
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
|
||||||
@@ -216,11 +210,8 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
self,
|
self,
|
||||||
nbits: int = 4,
|
nbits: int = 4,
|
||||||
group_size: int = 64,
|
group_size: int = 64,
|
||||||
quant_zero: bool = True,
|
|
||||||
quant_scale: bool = False,
|
|
||||||
offload_meta: bool = False,
|
|
||||||
view_as_float: bool = False,
|
view_as_float: bool = False,
|
||||||
axis: int = 0,
|
axis: Optional[int] = None,
|
||||||
dynamic_config: Optional[dict] = None,
|
dynamic_config: Optional[dict] = None,
|
||||||
skip_modules: List[str] = ["lm_head"],
|
skip_modules: List[str] = ["lm_head"],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -228,6 +219,16 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
if is_hqq_available():
|
if is_hqq_available():
|
||||||
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
||||||
|
|
||||||
|
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
|
||||||
|
if deprecated_key in kwargs:
|
||||||
|
logger.info(
|
||||||
|
deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
if axis is None:
|
||||||
|
axis = 1
|
||||||
|
logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
|
||||||
|
|
||||||
if axis not in [0, 1]:
|
if axis not in [0, 1]:
|
||||||
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
|
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
|
||||||
|
|
||||||
@@ -240,9 +241,6 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
**{
|
**{
|
||||||
"nbits": nbits,
|
"nbits": nbits,
|
||||||
"group_size": group_size,
|
"group_size": group_size,
|
||||||
"quant_zero": quant_zero,
|
|
||||||
"quant_scale": quant_scale,
|
|
||||||
"offload_meta": offload_meta,
|
|
||||||
"view_as_float": view_as_float,
|
"view_as_float": view_as_float,
|
||||||
"axis": axis,
|
"axis": axis,
|
||||||
}
|
}
|
||||||
@@ -259,12 +257,26 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
|
||||||
|
"""
|
||||||
|
instance = cls()
|
||||||
|
instance.quant_config = config["quant_config"]
|
||||||
|
instance.skip_modules = config["skip_modules"]
|
||||||
|
return instance
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a Python dictionary. Returns:
|
Serializes this instance to a Python dictionary. Returns:
|
||||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||||
"""
|
"""
|
||||||
return self.quant_config
|
return {
|
||||||
|
"quant_config": self.quant_config,
|
||||||
|
"quant_method": self.quant_method,
|
||||||
|
"skip_modules": self.skip_modules,
|
||||||
|
}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
config_dict = self.to_dict()
|
config_dict = self.to_dict()
|
||||||
|
|||||||
@@ -94,8 +94,7 @@ class HqqConfigTest(unittest.TestCase):
|
|||||||
quantization_config = HqqConfig()
|
quantization_config = HqqConfig()
|
||||||
hqq_orig_config = quantization_config.to_dict()
|
hqq_orig_config = quantization_config.to_dict()
|
||||||
|
|
||||||
for key in hqq_orig_config:
|
self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"])
|
||||||
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
|
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -109,32 +108,7 @@ class HQQTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Simple LLM model testing fp16
|
Simple LLM model testing fp16
|
||||||
"""
|
"""
|
||||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||||
|
|
||||||
hqq_runner = HQQLLMRunner(
|
|
||||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
|
||||||
)
|
|
||||||
|
|
||||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
|
||||||
check_forward(self, hqq_runner.model)
|
|
||||||
|
|
||||||
def test_f16_quantized_model_with_offloading(self):
|
|
||||||
"""
|
|
||||||
Simple LLM model testing bfp16 with meta-data offloading
|
|
||||||
"""
|
|
||||||
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
|
|
||||||
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
|
|
||||||
quant_config = HqqConfig(
|
|
||||||
dynamic_config={
|
|
||||||
"self_attn.q_proj": q4_config,
|
|
||||||
"self_attn.k_proj": q4_config,
|
|
||||||
"self_attn.v_proj": q4_config,
|
|
||||||
"self_attn.o_proj": q4_config,
|
|
||||||
"mlp.gate_proj": q3_config,
|
|
||||||
"mlp.up_proj": q3_config,
|
|
||||||
"mlp.down_proj": q3_config,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
hqq_runner = HQQLLMRunner(
|
hqq_runner = HQQLLMRunner(
|
||||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||||
@@ -157,7 +131,7 @@ class HQQTestMultiGPU(unittest.TestCase):
|
|||||||
Simple LLM model testing fp16 with multi-gpu
|
Simple LLM model testing fp16 with multi-gpu
|
||||||
"""
|
"""
|
||||||
|
|
||||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||||
|
|
||||||
hqq_runner = HQQLLMRunner(
|
hqq_runner = HQQLLMRunner(
|
||||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
|
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
|
||||||
@@ -165,3 +139,44 @@ class HQQTestMultiGPU(unittest.TestCase):
|
|||||||
|
|
||||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||||
check_forward(self, hqq_runner.model)
|
check_forward(self, hqq_runner.model)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_accelerate
|
||||||
|
class HQQSerializationTest(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
def test_model_serialization(self):
|
||||||
|
"""
|
||||||
|
Simple HQQ LLM save/load test
|
||||||
|
"""
|
||||||
|
quant_config = HqqConfig(nbits=4, group_size=64)
|
||||||
|
|
||||||
|
hqq_runner = HQQLLMRunner(
|
||||||
|
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits_ref = hqq_runner.model.forward(input_tensor).logits
|
||||||
|
|
||||||
|
# Save
|
||||||
|
saved_model_id = "quant_model"
|
||||||
|
hqq_runner.model.save_pretrained(saved_model_id)
|
||||||
|
|
||||||
|
# Remove old model
|
||||||
|
del hqq_runner.model
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Load and check if the logits match
|
||||||
|
model_loaded = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"quant_model", torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||||
|
|
||||||
|
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user