From 3e8f0fbf44f8223a50ce05864a7cea223085fb9a Mon Sep 17 00:00:00 2001 From: mobicham <37179323+mobicham@users.noreply.github.com> Date: Thu, 20 Mar 2025 15:31:49 +0100 Subject: [PATCH] Fix hqq skipped modules and dynamic quant (#36821) * Fix hqq skip_modules and dynamic_quant * fix skipped modules loading * add dynamic/skip HqqConfig test --- src/transformers/quantizers/quantizer_hqq.py | 27 ++++++++++++++-- tests/quantization/hqq/test_hqq.py | 33 ++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 60d334fdd9..8524e7dcec 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -124,7 +124,14 @@ class HqqHfQuantizer(HfQuantizer): # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params _valid_modules = set() _find_hqq_quantizable_layers(model, _valid_modules) - _valid_modules -= set(model.config.quantization_config["skip_modules"]) + + # Remove skipped modules + _skipped_modules = set() + for _module in _valid_modules: + for _skip_module in model.config.quantization_config["skip_modules"]: + if _skip_module in _module: + _skipped_modules.add(_module) + _valid_modules -= _skipped_modules # Append new expected layers based on _ref_keys _ref_keys = HQQLinear( @@ -243,10 +250,24 @@ class HqqHfQuantizer(HfQuantizer): # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module # directly doesn't work. - if hasattr(module, "quant_config"): + quant_config = model.config.quantization_config["quant_config"] + skip_modules = model.config.quantization_config["skip_modules"] + module_tag = ".".join(module.name.split(".")[-2:]) + module_quant_config = None + if "weight_quant_params" in quant_config: + module_quant_config = quant_config + elif module_tag in quant_config: + module_quant_config = quant_config[module_tag] + + for skip_module in skip_modules: + if skip_module in module.name: + module_quant_config = None + break + + if module_quant_config is not None: hqq_layer = HQQLinear( module, - module.quant_config, + quant_config=module_quant_config, compute_dtype=self.torch_dtype, device=target_device, del_orig=True, diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 7335a93708..031b3fefa5 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase): logits_loaded = model_loaded.forward(input_tensor).logits self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0) + + def test_model_serialization_dynamic_quant_with_skip(self): + """ + Simple HQQ LLM save/load test with dynamic quant + """ + q4_config = {"nbits": 4, "group_size": 64} + q3_config = {"nbits": 3, "group_size": 64} + + quant_config = HqqConfig( + dynamic_config={ + "self_attn.q_proj": q4_config, + "self_attn.k_proj": q4_config, + "self_attn.v_proj": q4_config, + "self_attn.o_proj": q4_config, + "mlp.gate_proj": q3_config, + "mlp.up_proj": q3_config, + }, + skip_modules=["lm_head", "down_proj"], + ) + + hqq_runner = HQQLLMRunner( + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device + ) + + model = hqq_runner.model + + input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device) + with torch.no_grad(): + model.forward(input_tensor).logits + + self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True) + self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4) + self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)