Fix : HQQ config when hqq not available (#35655)

* fix

* make style

* adding require_hqq

* make style
This commit is contained in:
Mohamed Mekkouri
2025-01-14 11:37:37 +01:00
committed by GitHub
parent 715fdd6459
commit 050636518a
3 changed files with 17 additions and 0 deletions

View File

@@ -87,6 +87,7 @@ from .utils import (
is_gguf_available, is_gguf_available,
is_grokadamw_available, is_grokadamw_available,
is_hadamard_available, is_hadamard_available,
is_hqq_available,
is_ipex_available, is_ipex_available,
is_jieba_available, is_jieba_available,
is_jinja_available, is_jinja_available,
@@ -1213,6 +1214,13 @@ def require_auto_gptq(test_case):
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
def require_hqq(test_case):
"""
Decorator for hqq dependency
"""
return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
def require_auto_awq(test_case): def require_auto_awq(test_case):
""" """
Decorator for auto_awq dependency Decorator for auto_awq dependency

View File

@@ -224,6 +224,10 @@ class HqqConfig(QuantizationConfigMixin):
): ):
if is_hqq_available(): if is_hqq_available():
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
else:
raise ImportError(
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
)
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]: for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
if deprecated_key in kwargs: if deprecated_key in kwargs:

View File

@@ -19,6 +19,7 @@ import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
require_accelerate, require_accelerate,
require_hqq,
require_torch_gpu, require_torch_gpu,
require_torch_multi_gpu, require_torch_multi_gpu,
slow, slow,
@@ -86,6 +87,7 @@ MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@require_torch_gpu @require_torch_gpu
@require_hqq
class HqqConfigTest(unittest.TestCase): class HqqConfigTest(unittest.TestCase):
def test_to_dict(self): def test_to_dict(self):
""" """
@@ -100,6 +102,7 @@ class HqqConfigTest(unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
@require_accelerate @require_accelerate
@require_hqq
class HQQTest(unittest.TestCase): class HQQTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
cleanup() cleanup()
@@ -122,6 +125,7 @@ class HQQTest(unittest.TestCase):
@require_torch_gpu @require_torch_gpu
@require_torch_multi_gpu @require_torch_multi_gpu
@require_accelerate @require_accelerate
@require_hqq
class HQQTestMultiGPU(unittest.TestCase): class HQQTestMultiGPU(unittest.TestCase):
def tearDown(self): def tearDown(self):
cleanup() cleanup()
@@ -144,6 +148,7 @@ class HQQTestMultiGPU(unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
@require_accelerate @require_accelerate
@require_hqq
class HQQSerializationTest(unittest.TestCase): class HQQSerializationTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
cleanup() cleanup()