Hqq serialization (#33141)

* HQQ model serialization attempt

* fix hqq dispatch and unexpected keys

* style

* remove check_old_param

* revert to check HQQLinear in quantizer_hqq.py

* revert to check HQQLinear in quantizer_hqq.py

* update HqqConfig default params

* make ci happy

* make ci happy

* revert to HQQLinear check in quantizer_hqq.py

* check hqq_min version 0.2.0

* set axis=1 as default in quantization_config.py

* validate_env with hqq>=0.2.0 version message

* deprecated hqq kwargs message

* make ci happy

* remove run_expected_keys_check hack + bump to 0.2.1 min hqq version

* fix unexpected_keys hqq update

* add pre_quantized check

* add update_expected_keys to base quantizer

* ci base.py fix?

* ci base.py fix?

* fix "quantization" typo in src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix post merge

---------

Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
mobicham
2024-09-30 14:47:18 +02:00
committed by GitHub
parent 4d5b458704
commit f5247aca01
8 changed files with 215 additions and 61 deletions

View File

@@ -94,8 +94,7 @@ class HqqConfigTest(unittest.TestCase):
quantization_config = HqqConfig()
hqq_orig_config = quantization_config.to_dict()
for key in hqq_orig_config:
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"])
@slow
@@ -109,32 +108,7 @@ class HQQTest(unittest.TestCase):
"""
Simple LLM model testing fp16
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_f16_quantized_model_with_offloading(self):
"""
Simple LLM model testing bfp16 with meta-data offloading
"""
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
quant_config = HqqConfig(
dynamic_config={
"self_attn.q_proj": q4_config,
"self_attn.k_proj": q4_config,
"self_attn.v_proj": q4_config,
"self_attn.o_proj": q4_config,
"mlp.gate_proj": q3_config,
"mlp.up_proj": q3_config,
"mlp.down_proj": q3_config,
}
)
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
@@ -157,7 +131,7 @@ class HQQTestMultiGPU(unittest.TestCase):
Simple LLM model testing fp16 with multi-gpu
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
quant_config = HqqConfig(nbits=8, group_size=64)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
@@ -165,3 +139,44 @@ class HQQTestMultiGPU(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
@slow
@require_torch_gpu
@require_accelerate
class HQQSerializationTest(unittest.TestCase):
    """Save/load round-trip test for an HQQ-quantized causal LM."""

    def tearDown(self):
        # Delegate post-test cleanup to the module-level `cleanup()` helper.
        cleanup()

    def test_model_serialization(self):
        """
        Simple HQQ LLM save/load test
        """
        # Function-scope import keeps this test self-contained.
        import tempfile

        quant_config = HqqConfig(nbits=4, group_size=64)
        hqq_runner = HQQLLMRunner(
            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
        )

        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)

        # Reference logits from the freshly quantized (never serialized) model.
        with torch.no_grad():
            logits_ref = hqq_runner.model.forward(input_tensor).logits

        # Use a temporary directory so the test leaves no artifacts in the
        # current working directory and cannot pick up stale files from a
        # previous run. The same path variable is used for save and load.
        with tempfile.TemporaryDirectory() as saved_model_id:
            # Save
            hqq_runner.model.save_pretrained(saved_model_id)

            # Remove old model
            del hqq_runner.model
            torch.cuda.empty_cache()

            # Load and check if the logits match
            model_loaded = AutoModelForCausalLM.from_pretrained(
                saved_model_id, torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True
            )

            with torch.no_grad():
                logits_loaded = model_loaded.forward(input_tensor).logits

        # The mean absolute difference must be exactly zero.
        self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)