Display a warning for an unknown quantization config instead of raising an error (#35963)
* add supports_quant_method check
* fix
* add test and fix suggestions
* change logic slightly

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -3634,7 +3634,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
pre_quantized = getattr(config, "quantization_config", None) is not None
|
pre_quantized = hasattr(config, "quantization_config")
|
||||||
|
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
|
||||||
|
pre_quantized = False
|
||||||
|
|
||||||
if pre_quantized or quantization_config is not None:
|
if pre_quantized or quantization_config is not None:
|
||||||
if pre_quantized:
|
if pre_quantized:
|
||||||
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
|
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
|
||||||
@@ -3647,7 +3650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
config.quantization_config,
|
config.quantization_config,
|
||||||
pre_quantized=pre_quantized,
|
pre_quantized=pre_quantized,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
hf_quantizer = None
|
hf_quantizer = None
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import warnings
|
|||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from ..models.auto.configuration_auto import AutoConfig
|
from ..models.auto.configuration_auto import AutoConfig
|
||||||
|
from ..utils import logging
|
||||||
from ..utils.quantization_config import (
|
from ..utils.quantization_config import (
|
||||||
AqlmConfig,
|
AqlmConfig,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
@@ -82,6 +83,8 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
|||||||
"vptq": VptqConfig,
|
"vptq": VptqConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AutoQuantizationConfig:
|
class AutoQuantizationConfig:
|
||||||
"""
|
"""
|
||||||
@@ -195,3 +198,23 @@ class AutoHfQuantizer:
|
|||||||
warnings.warn(warning_msg)
|
warnings.warn(warning_msg)
|
||||||
|
|
||||||
return quantization_config
|
return quantization_config
|
||||||
|
|
||||||
|
@staticmethod
def supports_quant_method(quantization_config_dict):
    """
    Tell whether the quantization method described by a config dict is known.

    Args:
        quantization_config_dict:
            Mapping queried with `.get`. The keys consulted are `"quant_method"`,
            `"load_in_8bit"` and `"load_in_4bit"`.

    Returns:
        `True` when the resolved method is a key of `AUTO_QUANTIZATION_CONFIG_MAPPING`.
        Otherwise a warning is logged and `False` is returned, so the caller can skip
        quantization instead of failing.

    Raises:
        ValueError: If the dict has no `quant_method` and neither `load_in_8bit` nor
            `load_in_4bit` is truthy.
    """
    quant_method = quantization_config_dict.get("quant_method", None)
    load_in_4bit = quantization_config_dict.get("load_in_4bit", False)
    load_in_8bit = quantization_config_dict.get("load_in_8bit", False)

    if load_in_8bit or load_in_4bit:
        # The load_in_* flags take precedence over any explicit `quant_method`;
        # 4-bit wins when both flags are set.
        suffix = "_4bit" if load_in_4bit else "_8bit"
        quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
    elif quant_method is None:
        raise ValueError(
            "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
        )

    if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
        # Report the keys of the mapping that was actually checked, so the list of
        # supported types in the message always matches the decision taken here.
        logger.warning(
            f"Unknown quantization type, got {quant_method} - supported types are:"
            f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}. Hence, we will skip the quantization. "
            "To remove the warning, you can delete the quantization_config attribute in config.json"
        )
        return False
    return True
|
||||||
|
|||||||
@@ -1819,6 +1819,19 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertIsNone(model_outputs.past_key_values)
|
self.assertIsNone(model_outputs.past_key_values)
|
||||||
self.assertTrue(model.training)
|
self.assertTrue(model.training)
|
||||||
|
|
||||||
|
def test_unknown_quantization_config(self):
    """
    A checkpoint whose config carries an unrecognised `quant_method` must still load,
    emitting exactly one warning instead of raising.
    """
    bert_kwargs = {
        "vocab_size": 99,
        "hidden_size": 32,
        "num_hidden_layers": 5,
        "num_attention_heads": 4,
        "intermediate_size": 37,
    }
    with tempfile.TemporaryDirectory() as save_dir:
        config = BertConfig(**bert_kwargs)
        model = BertModel(config)
        # Attach the bogus quantization config only after the model is built, then save.
        config.quantization_config = {"quant_method": "unknown"}
        model.save_pretrained(save_dir)

        with self.assertLogs("transformers", level="WARNING") as captured:
            BertModel.from_pretrained(save_dir)

        self.assertEqual(len(captured.records), 1)
        first_message = captured.records[0].message
        self.assertTrue(first_message.startswith("Unknown quantization type, got"))
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user