Quantization / HQQ: Fix HQQ tests on our runner (#30668)

Update test_hqq.py
This commit is contained in:
Younes Belkada
2024-05-06 11:33:52 +02:00
committed by GitHub
parent a45c514899
commit 9c772ac888

View File

@@ -35,7 +35,7 @@ if is_hqq_available():
class HQQLLMRunner: class HQQLLMRunner:
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir): def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_id, model_id,
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
@@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model) check_forward(self, hqq_runner.model)
def test_bfp16_quantized_model_with_offloading(self): def test_f16_quantized_model_with_offloading(self):
""" """
Simple LLM model testing bfp16 with meta-data offloading Simple LLM model testing bfp16 with meta-data offloading
""" """
@@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
) )
hqq_runner = HQQLLMRunner( hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
) )
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)