Display warning for unknown quants config instead of an error (#35963)
* add supports_quant_method check * fix * add test and fix suggestions * change logic slightly --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -3634,7 +3634,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
model_kwargs = kwargs
|
||||
|
||||
pre_quantized = getattr(config, "quantization_config", None) is not None
|
||||
pre_quantized = hasattr(config, "quantization_config")
|
||||
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
|
||||
pre_quantized = False
|
||||
|
||||
if pre_quantized or quantization_config is not None:
|
||||
if pre_quantized:
|
||||
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
|
||||
@@ -3647,7 +3650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config.quantization_config,
|
||||
pre_quantized=pre_quantized,
|
||||
)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import warnings
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..utils import logging
|
||||
from ..utils.quantization_config import (
|
||||
AqlmConfig,
|
||||
AwqConfig,
|
||||
@@ -82,6 +83,8 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"vptq": VptqConfig,
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AutoQuantizationConfig:
|
||||
"""
|
||||
@@ -195,3 +198,23 @@ class AutoHfQuantizer:
|
||||
warnings.warn(warning_msg)
|
||||
|
||||
return quantization_config
|
||||
|
||||
@staticmethod
|
||||
def supports_quant_method(quantization_config_dict):
|
||||
quant_method = quantization_config_dict.get("quant_method", None)
|
||||
if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
|
||||
suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
|
||||
quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
|
||||
elif quant_method is None:
|
||||
raise ValueError(
|
||||
"The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
|
||||
)
|
||||
|
||||
if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
|
||||
logger.warning(
|
||||
f"Unknown quantization type, got {quant_method} - supported types are:"
|
||||
f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
|
||||
"To remove the warning, you can delete the quantization_config attribute in config.json"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -1819,6 +1819,19 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertIsNone(model_outputs.past_key_values)
|
||||
self.assertTrue(model.training)
|
||||
|
||||
def test_unknown_quantization_config(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = BertModel(config)
|
||||
config.quantization_config = {"quant_method": "unknown"}
|
||||
model.save_pretrained(tmpdir)
|
||||
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||
BertModel.from_pretrained(tmpdir)
|
||||
self.assertEqual(len(cm.records), 1)
|
||||
self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user