Fix: HQQ config when hqq is not available (#35655)
* fix
* make style
* adding require_hqq
* make style
This commit is contained in:
@@ -87,6 +87,7 @@ from .utils import (
|
|||||||
is_gguf_available,
|
is_gguf_available,
|
||||||
is_grokadamw_available,
|
is_grokadamw_available,
|
||||||
is_hadamard_available,
|
is_hadamard_available,
|
||||||
|
is_hqq_available,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_jieba_available,
|
is_jieba_available,
|
||||||
is_jinja_available,
|
is_jinja_available,
|
||||||
@@ -1213,6 +1214,13 @@ def require_auto_gptq(test_case):
|
|||||||
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
|
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_hqq(test_case):
|
||||||
|
"""
|
||||||
|
Decorator for hqq dependency
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_auto_awq(test_case):
|
def require_auto_awq(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator for auto_awq dependency
|
Decorator for auto_awq dependency
|
||||||
|
|||||||
@@ -224,6 +224,10 @@ class HqqConfig(QuantizationConfigMixin):
|
|||||||
):
|
):
|
||||||
if is_hqq_available():
|
if is_hqq_available():
|
||||||
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
|
||||||
|
)
|
||||||
|
|
||||||
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
|
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
|
||||||
if deprecated_key in kwargs:
|
if deprecated_key in kwargs:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import unittest
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
|
require_hqq,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
@@ -86,6 +87,7 @@ MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
@require_hqq
|
||||||
class HqqConfigTest(unittest.TestCase):
|
class HqqConfigTest(unittest.TestCase):
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
"""
|
"""
|
||||||
@@ -100,6 +102,7 @@ class HqqConfigTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
|
@require_hqq
|
||||||
class HQQTest(unittest.TestCase):
|
class HQQTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cleanup()
|
cleanup()
|
||||||
@@ -122,6 +125,7 @@ class HQQTest(unittest.TestCase):
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
|
@require_hqq
|
||||||
class HQQTestMultiGPU(unittest.TestCase):
|
class HQQTestMultiGPU(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cleanup()
|
cleanup()
|
||||||
@@ -144,6 +148,7 @@ class HQQTestMultiGPU(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
|
@require_hqq
|
||||||
class HQQSerializationTest(unittest.TestCase):
|
class HQQSerializationTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|||||||
Reference in New Issue
Block a user