From 050636518a8b19edb36eec76c9b7676571a115a5 Mon Sep 17 00:00:00 2001 From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Date: Tue, 14 Jan 2025 11:37:37 +0100 Subject: [PATCH] Fix : HQQ config when hqq not available (#35655) * fix * make style * adding require_hqq * make style --- src/transformers/testing_utils.py | 8 ++++++++ src/transformers/utils/quantization_config.py | 4 ++++ tests/quantization/hqq/test_hqq.py | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 89587d303e..aa26c2c3df 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -87,6 +87,7 @@ from .utils import ( is_gguf_available, is_grokadamw_available, is_hadamard_available, + is_hqq_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -1213,6 +1214,13 @@ def require_auto_gptq(test_case): return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) +def require_hqq(test_case): + """ + Decorator for hqq dependency + """ + return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case) + + def require_auto_awq(test_case): """ Decorator for auto_awq dependency diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 3160c3481d..1883e93deb 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -224,6 +224,10 @@ class HqqConfig(QuantizationConfigMixin): ): if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + else: + raise ImportError( + "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." + ) for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]: if deprecated_key in kwargs: diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 6d08a0f0e6..c25aada6ed 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -19,6 +19,7 @@ import unittest from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig from transformers.testing_utils import ( require_accelerate, + require_hqq, require_torch_gpu, require_torch_multi_gpu, slow, @@ -86,6 +87,7 @@ MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" @require_torch_gpu +@require_hqq class HqqConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -100,6 +102,7 @@ class HqqConfigTest(unittest.TestCase): @slow @require_torch_gpu @require_accelerate +@require_hqq class HQQTest(unittest.TestCase): def tearDown(self): cleanup() @@ -122,6 +125,7 @@ class HQQTest(unittest.TestCase): @require_torch_gpu @require_torch_multi_gpu @require_accelerate +@require_hqq class HQQTestMultiGPU(unittest.TestCase): def tearDown(self): cleanup() @@ -144,6 +148,7 @@ class HQQTestMultiGPU(unittest.TestCase): @slow @require_torch_gpu @require_accelerate +@require_hqq class HQQSerializationTest(unittest.TestCase): def tearDown(self): cleanup()