Hqq serialization (#33141)

* HQQ model serialization attempt

* fix hqq dispatch and unexpected keys

* style

* remove check_old_param

* revert to check HQQLinear in quantizer_hqq.py

* revert to check HQQLinear in quantizer_hqq.py

* update HqqConfig default params

* make ci happy

* make ci happy

* revert to HQQLinear check in quantizer_hqq.py

* check hqq_min version 0.2.0

* set axis=1 as default in quantization_config.py

* validate_env with hqq>=0.2.0 version message

* deprecated hqq kwargs message

* make ci happy

* remove run_expected_keys_check hack + bump to 0.2.1 min hqq version

* fix unexpected_keys hqq update

* add pre_quantized check

* add update_expected_keys to base quantizerr

* ci base.py fix?

* ci base.py fix?

* fix "quantization typo" src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix post merge

---------

Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
mobicham
2024-09-30 14:47:18 +02:00
committed by GitHub
parent 4d5b458704
commit f5247aca01
8 changed files with 215 additions and 61 deletions

6
docs/source/en/quantization/hqq.md Normal file → Executable file
View File

@@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
quant_config = HqqConfig(nbits=8, group_size=64)
```
``` Python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,

View File

@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_
has_been_replaced = True
# Add these fake parameters to avoid loading fail
for att in ["W_q", "meta"]:
setattr(module, att, None)
if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
@@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
if any(key in linear_tags for key in quant_config.keys()):
@@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
)
# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}
if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")

View File

@@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
# Not all the attributes of a module are Parameters/Tensor
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is None:
break
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
@@ -3819,6 +3824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_pt = not (from_tf | from_flax)
# load pt weights early so that we know which dtype to init the model under
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
@@ -4176,6 +4182,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
@@ -4290,7 +4299,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)

12
src/transformers/quantizers/base.py Normal file → Executable file
View File

@@ -109,6 +109,18 @@ class HfQuantizer(ABC):
"""
return missing_keys
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
"""
Override this method if you want to adjust the `update_expected_keys`.
Args:
expected_keys (`List[str]`, *optional*):
The list of the expected keys in the initialized model.
loaded_keys (`List[str]`, *optional*):
The list of the loaded keys in the checkpoint.
"""
return expected_keys
def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case

View File

@@ -62,7 +62,7 @@ class HqqHfQuantizer(HfQuantizer):
def validate_environment(self, *args, **kwargs):
if not (is_hqq_available()):
raise ImportError(
"HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
)
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
@@ -91,6 +91,65 @@ class HqqHfQuantizer(HfQuantizer):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1
def update_missing_keys(
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
) -> List[str]:
if self.pre_quantized:
return [key for key in missing_keys if ("weight" not in key)]
else:
return missing_keys
# Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
def update_expected_keys(
self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
) -> List[str]:
if not self.pre_quantized:
return expected_keys
# Collects all quantizable (linear) layers
def _find_hqq_quantizable_layers(model, layers):
for name, module in model.named_children():
if isinstance(module, (torch.nn.Linear)):
layers.add(module.name)
_find_hqq_quantizable_layers(module, layers)
new_keys = set(expected_keys)
if is_hqq_available():
from hqq.core.quantize import HQQLinear
# Name modules
for name, module in model.named_modules():
module.name = name
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
_valid_modules = set()
_find_hqq_quantizable_layers(model, _valid_modules)
_valid_modules -= set(model.config.quantization_config["skip_modules"])
# Append new expected layers based on _ref_keys
_ref_keys = HQQLinear(
linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
).state_dict_keys() - {"bias"}
# Clean-up
_rm_keys = set()
for key in new_keys:
if any(_module in key for _module in _valid_modules):
_rm_keys.add(key)
new_keys -= _rm_keys
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
# Re-populate Linear/HQQLinear
for _module in _valid_modules:
if _module + ".weight" in loaded_keys:
new_keys.add(_module + ".weight")
else:
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
if _module + ".bias" in loaded_keys:
new_keys.add(_module + ".bias")
return list(new_keys)
def check_quantized_param(
self,
model: "PreTrainedModel",
@@ -99,9 +158,18 @@ class HqqHfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
if self.pre_quantized:
return (
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
and tensor_name != "weight"
and tensor_name != "bias"
)
else:
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
def create_quantized_param(
self,
@@ -122,13 +190,43 @@ class HqqHfQuantizer(HfQuantizer):
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
layer_name = param_name.replace(".weight", "").replace(".bias", "")
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]
# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
# set module state_dict
module_state_dict = {}
for k, v in state_dict.items():
if layer_name + "." in k:
module_state_dict[k.split(".")[-1]] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
if self.pre_quantized:
if isinstance(module, HQQLinear):
return
else:
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None,
compute_dtype=self.torch_dtype,
device=target_device,
)
hqq_layer.load_state_dict(module_state_dict)
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
setattr(parent_module, node, hqq_layer)
# cleanup
del module.__dict__, module
torch.cuda.empty_cache()
return
# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
@@ -136,7 +234,6 @@ class HqqHfQuantizer(HfQuantizer):
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
# directly doesn't work.
if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
@@ -192,7 +289,7 @@ class HqqHfQuantizer(HfQuantizer):
return model
def is_serializable(self, safe_serialization=None):
return False
return True
@property
def is_trainable(self) -> bool:

View File

@@ -92,6 +92,7 @@ ACCELERATE_MIN_VERSION = "0.26.0"
FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -181,7 +182,7 @@ _torchao_available = _is_package_available("torchao")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
_tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
@@ -323,8 +324,8 @@ def is_torch_deterministic():
return True
def is_hqq_available():
return _hqq_available
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
def is_pygments_available():

View File

@@ -193,15 +193,9 @@ class HqqConfig(QuantizationConfigMixin):
Number of bits. Supported values are (8, 4, 3, 2, 1).
group_size (`int`, *optional*, defaults to 64):
Group-size value. Supported values are any value that is divisble by weight.shape[axis]).
quant_zero (`bool`, *optional*, defaults to `True`):
Quantize the zero-point if set to `True`.
quant_scale (`bool`, *optional*, defaults to `False`):
Quantize the scaling if set to `True`.
offload_meta (`bool`, *optional*, defaults to `False`):
Offload the meta-data to the CPU if set to `True`.
view_as_float (`bool`, *optional*, defaults to `False`):
View the quantized weight as float (used in distributed training) if set to `True`.
axis (`int`, *optional*, defaults to 0):
axis (`Optional[int]`, *optional*):
Axis along which grouping is performed. Supported values are 0 or 1.
dynamic_config (dict, *optional*):
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
@@ -216,11 +210,8 @@ class HqqConfig(QuantizationConfigMixin):
self,
nbits: int = 4,
group_size: int = 64,
quant_zero: bool = True,
quant_scale: bool = False,
offload_meta: bool = False,
view_as_float: bool = False,
axis: int = 0,
axis: Optional[int] = None,
dynamic_config: Optional[dict] = None,
skip_modules: List[str] = ["lm_head"],
**kwargs,
@@ -228,6 +219,16 @@ class HqqConfig(QuantizationConfigMixin):
if is_hqq_available():
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
if deprecated_key in kwargs:
logger.info(
deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
)
if axis is None:
axis = 1
logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
if axis not in [0, 1]:
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
@@ -240,9 +241,6 @@ class HqqConfig(QuantizationConfigMixin):
**{
"nbits": nbits,
"group_size": group_size,
"quant_zero": quant_zero,
"quant_scale": quant_scale,
"offload_meta": offload_meta,
"view_as_float": view_as_float,
"axis": axis,
}
@@ -259,12 +257,26 @@ class HqqConfig(QuantizationConfigMixin):
"""
pass
@classmethod
def from_dict(cls, config: Dict[str, Any]):
"""
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
"""
instance = cls()
instance.quant_config = config["quant_config"]
instance.skip_modules = config["skip_modules"]
return instance
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
return {
"quant_config": self.quant_config,
"quant_method": self.quant_method,
"skip_modules": self.skip_modules,
}
def __repr__(self):
config_dict = self.to_dict()

View File

@@ -94,8 +94,7 @@ class HqqConfigTest(unittest.TestCase):
quantization_config = HqqConfig()
hqq_orig_config = quantization_config.to_dict()
for key in hqq_orig_config:
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"])
@slow
@@ -109,32 +108,7 @@ class HQQTest(unittest.TestCase):
"""
Simple LLM model testing fp16
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_f16_quantized_model_with_offloading(self):
"""
Simple LLM model testing bfp16 with meta-data offloading
"""
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
quant_config = HqqConfig(
dynamic_config={
"self_attn.q_proj": q4_config,
"self_attn.k_proj": q4_config,
"self_attn.v_proj": q4_config,
"self_attn.o_proj": q4_config,
"mlp.gate_proj": q3_config,
"mlp.up_proj": q3_config,
"mlp.down_proj": q3_config,
}
)
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
@@ -157,7 +131,7 @@ class HQQTestMultiGPU(unittest.TestCase):
Simple LLM model testing fp16 with multi-gpu
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
@@ -165,3 +139,44 @@ class HQQTestMultiGPU(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
@slow
@require_torch_gpu
@require_accelerate
class HQQSerializationTest(unittest.TestCase):
def tearDown(self):
cleanup()
def test_model_serialization(self):
"""
Simple HQQ LLM save/load test
"""
quant_config = HqqConfig(nbits=4, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
with torch.no_grad():
logits_ref = hqq_runner.model.forward(input_tensor).logits
# Save
saved_model_id = "quant_model"
hqq_runner.model.save_pretrained(saved_model_id)
# Remove old model
del hqq_runner.model
torch.cuda.empty_cache()
# Load and check if the logits match
model_loaded = AutoModelForCausalLM.from_pretrained(
"quant_model", torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True
)
with torch.no_grad():
logits_loaded = model_loaded.forward(input_tensor).logits
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)