Fix HQQ model param device transfer issue (#38466)
* Fix HQQ model param device transfer issue * Modify a comment * Clean up the code and add a test for HQQ device/dtype changes * Fix import code-quality issues in the HQQ test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import accelerate
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -119,6 +121,41 @@ class HQQTest(unittest.TestCase):
|
||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||
check_forward(self, hqq_runner.model)
|
||||
|
||||
def test_quantized_model_to_new_device_and_new_dtype(self):
    """
    Simple LLM model test covering device and dtype changes.

    Runs `check_hqqlayer` and `check_forward` on the freshly quantized model,
    again after `.to("cpu", torch.bfloat16)`, and once more after `.cuda(...)`
    back onto the device the layer started on.
    """
    runner = HQQLLMRunner(
        model_id=MODEL_ID,
        quant_config=HqqConfig(nbits=8, group_size=64),
        compute_dtype=torch.float16,
        device=torch_device,
    )

    def run_checks():
        # Look the layer up on every call instead of caching a reference.
        check_hqqlayer(self, runner.model.model.layers[0].self_attn.v_proj)
        check_forward(self, runner.model)

    # Remember where the quantized layer started so we can move back later.
    starting_device = runner.model.model.layers[0].self_attn.v_proj.device
    run_checks()

    # Remove `accelerate` hooks so the model can be moved to a new device
    accelerate.hooks.remove_hook_from_module(runner.model, recurse=True)

    runner.model.to("cpu", torch.bfloat16)
    run_checks()

    runner.model.cuda(starting_device)
    run_checks()
|
||||
|
||||
def test_quantized_model_fake_weight_dtype(self):
    """
    Check the dtype reported by the `weight` attribute of a quantized layer.
    """
    runner = HQQLLMRunner(
        model_id=MODEL_ID,
        quant_config=HqqConfig(nbits=8, group_size=64),
        compute_dtype=torch.float16,
        device=torch_device,
    )

    # We use a hack to inject a fake weight to HQQLinear. Check that it works
    v_proj = runner.model.model.layers[0].self_attn.v_proj
    self.assertEqual(v_proj.weight.dtype, torch.float16)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user