Display warning for unknown quants config instead of an error (#35963)

* add supports_quant_method check

* fix

* add test and fix suggestions

* change logic slightly

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Marc Sun
2025-02-04 15:17:01 +01:00
committed by GitHub
parent f19bfa50e7
commit 9f486badd5
3 changed files with 40 additions and 2 deletions

View File

@@ -3634,7 +3634,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_kwargs = kwargs
pre_quantized = getattr(config, "quantization_config", None) is not None
pre_quantized = hasattr(config, "quantization_config")
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
pre_quantized = False
if pre_quantized or quantization_config is not None:
if pre_quantized:
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
@@ -3647,7 +3650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config.quantization_config,
pre_quantized=pre_quantized,
)
else:
hf_quantizer = None

View File

@@ -15,6 +15,7 @@ import warnings
from typing import Dict, Optional, Union
from ..models.auto.configuration_auto import AutoConfig
from ..utils import logging
from ..utils.quantization_config import (
AqlmConfig,
AwqConfig,
@@ -82,6 +83,8 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"vptq": VptqConfig,
}
logger = logging.get_logger(__name__)
class AutoQuantizationConfig:
"""
@@ -195,3 +198,23 @@ class AutoHfQuantizer:
warnings.warn(warning_msg)
return quantization_config
@staticmethod
def supports_quant_method(quantization_config_dict):
quant_method = quantization_config_dict.get("quant_method", None)
if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
elif quant_method is None:
raise ValueError(
"The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
)
if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
logger.warning(
f"Unknown quantization type, got {quant_method} - supported types are:"
f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
"To remove the warning, you can delete the quantization_config attribute in config.json"
)
return False
return True

View File

@@ -1819,6 +1819,19 @@ class ModelUtilsTest(TestCasePlus):
self.assertIsNone(model_outputs.past_key_values)
self.assertTrue(model.training)
def test_unknown_quantization_config(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
config.quantization_config = {"quant_method": "unknown"}
model.save_pretrained(tmpdir)
with self.assertLogs("transformers", level="WARNING") as cm:
BertModel.from_pretrained(tmpdir)
self.assertEqual(len(cm.records), 1)
self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))
@slow
@require_torch