Fix hqq skipped modules and dynamic quant (#36821)

* Fix hqq skip_modules and dynamic_quant

* fix skipped modules loading

* add dynamic/skip HqqConfig test
This commit is contained in:
mobicham
2025-03-20 15:31:49 +01:00
committed by GitHub
parent 055afdb6bb
commit 3e8f0fbf44
2 changed files with 57 additions and 3 deletions

View File

@@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase):
logits_loaded = model_loaded.forward(input_tensor).logits
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
def test_model_serialization_dynamic_quant_with_skip(self):
"""
Simple HQQ LLM save/load test with dynamic quant
"""
q4_config = {"nbits": 4, "group_size": 64}
q3_config = {"nbits": 3, "group_size": 64}
quant_config = HqqConfig(
dynamic_config={
"self_attn.q_proj": q4_config,
"self_attn.k_proj": q4_config,
"self_attn.v_proj": q4_config,
"self_attn.o_proj": q4_config,
"mlp.gate_proj": q3_config,
"mlp.up_proj": q3_config,
},
skip_modules=["lm_head", "down_proj"],
)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
model = hqq_runner.model
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
with torch.no_grad():
model.forward(input_tensor).logits
self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)