Hqq serialization (#33141)
* HQQ model serialization attempt * fix hqq dispatch and unexpected keys * style * remove check_old_param * revert to check HQQLinear in quantizer_hqq.py * revert to check HQQLinear in quantizer_hqq.py * update HqqConfig default params * make ci happy * make ci happy * revert to HQQLinear check in quantizer_hqq.py * check hqq_min version 0.2.0 * set axis=1 as default in quantization_config.py * validate_env with hqq>=0.2.0 version message * deprecated hqq kwargs message * make ci happy * remove run_expected_keys_check hack + bump to 0.2.1 min hqq version * fix unexpected_keys hqq update * add pre_quantized check * add update_expected_keys to base quantizerr * ci base.py fix? * ci base.py fix? * fix "quantization typo" src/transformers/utils/quantization_config.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix post merge --------- Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -94,8 +94,7 @@ class HqqConfigTest(unittest.TestCase):
|
||||
quantization_config = HqqConfig()
|
||||
hqq_orig_config = quantization_config.to_dict()
|
||||
|
||||
for key in hqq_orig_config:
|
||||
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
|
||||
self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"])
|
||||
|
||||
|
||||
@slow
|
||||
@@ -109,32 +108,7 @@ class HQQTest(unittest.TestCase):
|
||||
"""
|
||||
Simple LLM model testing fp16
|
||||
"""
|
||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||
check_forward(self, hqq_runner.model)
|
||||
|
||||
def test_f16_quantized_model_with_offloading(self):
|
||||
"""
|
||||
Simple LLM model testing bfp16 with meta-data offloading
|
||||
"""
|
||||
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
|
||||
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
|
||||
quant_config = HqqConfig(
|
||||
dynamic_config={
|
||||
"self_attn.q_proj": q4_config,
|
||||
"self_attn.k_proj": q4_config,
|
||||
"self_attn.v_proj": q4_config,
|
||||
"self_attn.o_proj": q4_config,
|
||||
"mlp.gate_proj": q3_config,
|
||||
"mlp.up_proj": q3_config,
|
||||
"mlp.down_proj": q3_config,
|
||||
}
|
||||
)
|
||||
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
@@ -157,7 +131,7 @@ class HQQTestMultiGPU(unittest.TestCase):
|
||||
Simple LLM model testing fp16 with multi-gpu
|
||||
"""
|
||||
|
||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
||||
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
|
||||
@@ -165,3 +139,44 @@ class HQQTestMultiGPU(unittest.TestCase):
|
||||
|
||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||
check_forward(self, hqq_runner.model)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_accelerate
|
||||
class HQQSerializationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
cleanup()
|
||||
|
||||
def test_model_serialization(self):
|
||||
"""
|
||||
Simple HQQ LLM save/load test
|
||||
"""
|
||||
quant_config = HqqConfig(nbits=4, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_ref = hqq_runner.model.forward(input_tensor).logits
|
||||
|
||||
# Save
|
||||
saved_model_id = "quant_model"
|
||||
hqq_runner.model.save_pretrained(saved_model_id)
|
||||
|
||||
# Remove old model
|
||||
del hqq_runner.model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load and check if the logits match
|
||||
model_loaded = AutoModelForCausalLM.from_pretrained(
|
||||
"quant_model", torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||
|
||||
Reference in New Issue
Block a user