Fix HQQ model param device transfer issue (#38466)

* Fix HQQ model param device transfer issue

* modify a comment

* clean up the code and add a test for HQQ device/dtype transfer

* fix import code quality in the HQQ test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
艾梦
2025-06-18 21:09:00 +08:00
committed by GitHub
parent c77bcd889f
commit cb0f604192
3 changed files with 85 additions and 4 deletions

View File

@@ -15,6 +15,8 @@
import gc
import unittest
import accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from transformers.testing_utils import (
backend_empty_cache,
@@ -119,6 +121,41 @@ class HQQTest(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_quantized_model_to_new_device_and_new_dtype(self):
    """
    Check that a quantized LLM keeps working after device and dtype changes.

    The model is loaded on `torch_device` with a float16 compute dtype, moved
    to CPU as bfloat16, and finally moved back to the device it started on.
    The quantized layer check and the forward check run after every step.
    """
    hqq_runner = HQQLLMRunner(
        model_id=MODEL_ID,
        quant_config=HqqConfig(nbits=8, group_size=64),
        compute_dtype=torch.float16,
        device=torch_device,
    )

    def run_checks():
        # Re-read the layer and the model from the runner on every call so
        # each check sees the state left by the most recent move.
        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
        check_forward(self, hqq_runner.model)

    # Remember where the quantized layer lives so we can return to it later.
    original_device = hqq_runner.model.model.layers[0].self_attn.v_proj.device
    run_checks()

    # Remove `accelerate` hooks so the model can be moved to a new device
    accelerate.hooks.remove_hook_from_module(hqq_runner.model, recurse=True)

    # Move to CPU and switch dtype in a single call.
    hqq_runner.model.to("cpu", torch.bfloat16)
    run_checks()

    # Go back to the starting device through the `.cuda()` entry point.
    hqq_runner.model.cuda(original_device)
    run_checks()
def test_quantized_model_fake_weight_dtype(self):
    """
    Check the dtype reported by the `weight` attribute of a quantized layer.

    The model is loaded with a float16 compute dtype and the `weight` of the
    first layer's `v_proj` is expected to report float16.
    """
    hqq_runner = HQQLLMRunner(
        model_id=MODEL_ID,
        quant_config=HqqConfig(nbits=8, group_size=64),
        compute_dtype=torch.float16,
        device=torch_device,
    )

    # We use a hack to inject a fake weight to HQQLinear. Check that it works
    v_proj = hqq_runner.model.model.layers[0].self_attn.v_proj
    self.assertEqual(v_proj.weight.dtype, torch.float16)
@slow
@require_torch_gpu