From 9f486badd584756d60479f2fe257d9b8e8c761b9 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 4 Feb 2025 15:17:01 +0100
Subject: [PATCH] Display warning for unknown quants config instead of an error (#35963)

* add supports_quant_method check

* fix

* add test and fix suggestions

* change logic slightly

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
---
Review notes (this section is ignored by `git am`):
- modeling_utils.py: the `getattr(..., None) is not None` check is kept. `hasattr` is also true for a
  `quantization_config` that is explicitly `None`, which would then reach `.get` on `None`.
- quantizers/auto.py: the warning now lists the keys of the mapping that is actually checked, and
  `supports_quant_method` is documented.

 src/transformers/modeling_utils.py  |  5 ++++-
 src/transformers/quantizers/auto.py | 41 +++++++++++++++++++++++++++++++++++++++++
 tests/utils/test_modeling_utils.py  | 13 +++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b31368c606..1c67ee1f89 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3634,7 +3634,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
             model_kwargs = kwargs
 
         pre_quantized = getattr(config, "quantization_config", None) is not None
+        # An unknown quantization method only logs a warning: the checkpoint is then treated as not pre-quantized.
+        if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
+            pre_quantized = False
+
         if pre_quantized or quantization_config is not None:
             if pre_quantized:
                 config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
@@ -3647,7 +3651,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 config.quantization_config,
                 pre_quantized=pre_quantized,
             )
-
         else:
             hf_quantizer = None
 
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index d5b51d038a..cdb569dd1c 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -15,6 +15,7 @@ import warnings
 from typing import Dict, Optional, Union
 
 from ..models.auto.configuration_auto import AutoConfig
+from ..utils import logging
 from ..utils.quantization_config import (
     AqlmConfig,
     AwqConfig,
@@ -82,6 +83,8 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "vptq": VptqConfig,
 }
 
+logger = logging.get_logger(__name__)
+
 
 class AutoQuantizationConfig:
     """
@@ -195,3 +198,41 @@ class AutoHfQuantizer:
             warnings.warn(warning_msg)
 
         return quantization_config
+
+    @staticmethod
+    def supports_quant_method(quantization_config_dict):
+        """
+        Check whether the quantization method described by `quantization_config_dict` is a known one.
+
+        Args:
+            quantization_config_dict (`dict`):
+                Quantization config in dict form. Only its `quant_method`, `load_in_8bit` and `load_in_4bit`
+                entries are read.
+
+        Returns:
+            `bool`: `True` if the resolved method is a key of `AUTO_QUANTIZATION_CONFIG_MAPPING`. Otherwise a
+            warning is logged and `False` is returned so that the caller can skip quantization.
+
+        Raises:
+            ValueError: If `quant_method` is missing and neither `load_in_8bit` nor `load_in_4bit` is set.
+        """
+        quant_method = quantization_config_dict.get("quant_method", None)
+        load_in_4bit = quantization_config_dict.get("load_in_4bit", False)
+        load_in_8bit = quantization_config_dict.get("load_in_8bit", False)
+        if load_in_4bit or load_in_8bit:
+            # The loading flags override any `quant_method` entry; the 4-bit flag wins if both are set.
+            suffix = "_4bit" if load_in_4bit else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
+            raise ValueError(
+                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
+            )
+
+        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
+            logger.warning(
+                f"Unknown quantization type, got {quant_method} - supported types are:"
+                f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}. Hence, we will skip the quantization. "
+                "To remove the warning, you can delete the quantization_config attribute in config.json"
+            )
+            return False
+        return True
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index dd52927a25..2179c4be57 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1819,6 +1819,19 @@ class ModelUtilsTest(TestCasePlus):
         self.assertIsNone(model_outputs.past_key_values)
         self.assertTrue(model.training)
 
+    def test_unknown_quantization_config(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config = BertConfig(
+                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+            )
+            model = BertModel(config)
+            config.quantization_config = {"quant_method": "unknown"}
+            model.save_pretrained(tmpdir)
+            with self.assertLogs("transformers", level="WARNING") as cm:
+                BertModel.from_pretrained(tmpdir)
+            self.assertEqual(len(cm.records), 1)
+            self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))
+
     @slow
     @require_torch