Fix hqq skipped modules and dynamic quant (#36821)
* Fix hqq skip_modules and dynamic_quant * fix skipped modules loading * add dynamic/skip HqqConfig test
This commit is contained in:
@@ -207,3 +207,36 @@ class HQQSerializationTest(unittest.TestCase):
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||
|
||||
def test_model_serialization_dynamic_quant_with_skip(self):
|
||||
"""
|
||||
Simple HQQ LLM save/load test with dynamic quant
|
||||
"""
|
||||
q4_config = {"nbits": 4, "group_size": 64}
|
||||
q3_config = {"nbits": 3, "group_size": 64}
|
||||
|
||||
quant_config = HqqConfig(
|
||||
dynamic_config={
|
||||
"self_attn.q_proj": q4_config,
|
||||
"self_attn.k_proj": q4_config,
|
||||
"self_attn.v_proj": q4_config,
|
||||
"self_attn.o_proj": q4_config,
|
||||
"mlp.gate_proj": q3_config,
|
||||
"mlp.up_proj": q3_config,
|
||||
},
|
||||
skip_modules=["lm_head", "down_proj"],
|
||||
)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
model = hqq_runner.model
|
||||
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||
with torch.no_grad():
|
||||
model.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
|
||||
self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
|
||||
self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)
|
||||
|
||||
Reference in New Issue
Block a user