diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 4758529c6b..0f33a5b6b7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -137,6 +137,8 @@
title: Overview
- local: quantization
title: Quantization
+ - local: hf_quantizer
+ title: Contribute new quantization method
- sections:
- local: perf_train_gpu_one
title: Methods and tools for efficient training on a single GPU
diff --git a/docs/source/en/hf_quantizer.md b/docs/source/en/hf_quantizer.md
new file mode 100644
index 0000000000..91a0136213
--- /dev/null
+++ b/docs/source/en/hf_quantizer.md
@@ -0,0 +1,70 @@
+
+
+# Contribute new quantization method
+
+Transformers supports and integrates many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are other quantization approaches that are not yet integrated. To make it easier to add and use these quantization methods with Transformers models, use the [`HfQuantizer`] class. [`HfQuantizer`] is designed as an internal helper class for adding a quantization method, not as something you apply to every PyTorch module.
+
+This guide will show you how to integrate a new quantization method with the [`HfQuantizer`] class.
+
+
+## Requirements
+
+Before integrating a new quantization method into Transformers, ensure the method you are trying to add meets the following prerequisites. Only quantization methods that can be run with PyTorch modules are currently supported.
+
+- The quantization method is available through a Python package that is pip-installable by anyone (it is also fine if you can only install the package from source). Ideally, pre-compiled kernels are included in the pip package.
+- The method can run on commonly-used hardware (CPU, GPU, ...).
+- The method is wrapped in a `nn.Module` (e.g., `Linear8bitLt`, `Linear4bit`), and the quantized linear layer should have the following definition:
+
+```py
+class Linear4bit(nn.Module):
+ def __init__(self, ...):
+ ...
+
+ def forward(self, x):
+ return my_4bit_kernel(x, self.weight, self.bias)
+```
+This way, Transformers models can be easily quantized by replacing some instances of `nn.Linear` with a target class.
+- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
+- Make sure the package that contains the quantization kernels/primitives is stable (no frequent breaking changes).
+
+Some quantization methods may require "pre-quantizing" the models through data calibration (e.g., AWQ). In this case, we prefer to only support inference in Transformers and let the third-party library maintained by the ML community deal with the model quantization itself.
+
+## Build a new HfQuantizer class
+
+1. 📕 Create a new quantization config class inside `src/transformers/utils/quantization_config.py` and expose it in the Transformers main `init` by adding it to the `_import_structure` object of `src/transformers/__init__.py`.
+
+2. 🗃 Create a new file inside `src/transformers/quantizers/` named `quantizer_your_method.py`, and define a class in it that inherits from `src/transformers/quantizers/base.py::HfQuantizer`. Make sure to add the new quantizer and quantization config to the quantization auto-mapping in `src/transformers/quantizers/auto.py`.
+
+3. 🔩 Define the following class attributes/property methods for your quantization method:
+
+* `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, Transformers supports only inference with pre-quantized weights, not quantizing the model itself.
+* `required_packages`: A list of strings of the required packages to use the quantized weights. You might need to define some new utility methods such as `is_auto_awq_available` in `src/transformers/utils/import_utils.py`.
+* `requires_parameters_quantization`: Only required if your quantization method requires extra attention to the underlying `nn.Parameter` object. For example, bitsandbytes uses `Params4bit` and `Int8Params`, which require some extra attention when quantizing the model. Most recent quantization methods pack int2/int4 weights inside `torch.uint8` weights, so this flag should rarely be needed (it is set to `False` by default).
+* `is_serializable`: A property method to determine whether the method is serializable or not.
+* `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
+
+
+4. 🪛 Write the `validate_environment` and `update_torch_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. See how other quantizers implement them for reference.
+
+5. 🖋 Write the `_process_model_before_weight_loading` method. In Transformers, quantized models are first initialized on the `"meta"` device before the weights are loaded. This lets the `_process_model_before_weight_loading` method manipulate the model skeleton to replace some modules (e.g., `nn.Linear`) with the target quantization modules. You can define the module replacement logic or any other utility method by creating a new file in `src/transformers/integrations/` and exposing the relevant methods in that folder's `__init__.py` file. A good starting point is to look at another quantization method, such as `quantizer_awq.py`.
+
+6. 🖊 Write the `_process_model_after_weight_loading` method. This method enables implementing additional features that require manipulating the model after loading the weights.
+
+7. 📖 Document everything! Make sure your quantization method is documented in the `docs/source/en/quantization.md` file.
+
+8. 🟢 Add tests! You should add tests by first adding the package in our nightly Dockerfile inside `docker/transformers-all-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out how it is implemented for other quantization methods.
+
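Reviewer sketch (not part of the patch): steps 3 to 6 of the guide above can be pictured with the dependency-free toy below. `QuantizerBase` is a stand-in written for this note, not the real `HfQuantizer`. `MyMethodQuantizer`, `MyMethodLinear`, the `my_method` package name, the `QUANTIZER_MAPPING` dict, and the dict-based model skeleton are all hypothetical. Only the hook and attribute names are taken from the diff; how they are wired together here (for example `preprocess_model` delegating to `_process_model_before_weight_loading`, or the calibration check in `__init__`) is an assumption.

```python
class QuantizerBase:
    """Stand-in for `HfQuantizer` (assumption: NOT the real base class)."""

    requires_calibration = False
    required_packages = None
    requires_parameters_quantization = False

    def __init__(self, quantization_config, pre_quantized=True):
        self.quantization_config = quantization_config
        self.pre_quantized = pre_quantized
        # Assumed behavior: calibration-based methods only accept pre-quantized weights.
        if self.requires_calibration and not pre_quantized:
            raise ValueError("This method requires calibration; load pre-quantized weights.")

    def validate_environment(self, **kwargs):
        """Raise early if the user's setup cannot run this method."""

    def update_torch_dtype(self, torch_dtype):
        return torch_dtype

    def preprocess_model(self, model, **kwargs):
        # Assumed wiring: the public entry point delegates to the private hook.
        return self._process_model_before_weight_loading(model, **kwargs)

    def _process_model_before_weight_loading(self, model, **kwargs):
        raise NotImplementedError


class MyMethodQuantizer(QuantizerBase):
    """Hypothetical quantizer for an imaginary `my_method` package."""

    requires_calibration = True
    required_packages = ["my_method"]  # hypothetical package name

    def validate_environment(self, **kwargs):
        if kwargs.get("from_tf") or kwargs.get("from_flax"):
            raise ValueError("Only PyTorch weights are supported.")

    def update_torch_dtype(self, torch_dtype):
        # Dtypes are plain strings here to keep the sketch free of torch.
        return "float16" if torch_dtype is None else torch_dtype

    def _process_model_before_weight_loading(self, model, **kwargs):
        # The toy "skeleton" maps module names to layer kinds; swap linear layers in place.
        skip = set(kwargs.get("keep_in_fp32_modules") or [])
        for name, kind in list(model.items()):
            if kind == "Linear" and name not in skip:
                model[name] = "MyMethodLinear"
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_serializable(self):
        return True

    @property
    def is_trainable(self):
        return False


# Hypothetical auto-mapping from the config's `quant_method` to a quantizer class.
QUANTIZER_MAPPING = {"my_method": MyMethodQuantizer}


def quantizer_from_config(config, pre_quantized=True):
    return QUANTIZER_MAPPING[config["quant_method"]](config, pre_quantized=pre_quantized)


quantizer = quantizer_from_config({"quant_method": "my_method"})
quantizer.validate_environment(from_tf=False, from_flax=False)
dtype = quantizer.update_torch_dtype(None)
skeleton = {"attn.q_proj": "Linear", "lm_head": "Linear", "norm": "LayerNorm"}
quantizer.preprocess_model(skeleton, keep_in_fp32_modules=["lm_head"])
```

The call order at the bottom mirrors the order in which `from_pretrained` invokes the quantizer in this diff: environment validation, dtype update, then model preprocessing before any weights are loaded.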
diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md
index 3a1c542c0b..c56e69b08e 100644
--- a/docs/source/en/quantization.md
+++ b/docs/source/en/quantization.md
@@ -20,6 +20,12 @@ Quantization techniques focus on representing data with less information while a
Transformers supports several quantization schemes to help you run inference with large language models (LLMs) and finetune adapters on quantized models. This guide will show you how to use Activation-aware Weight Quantization (AWQ), AutoGPTQ, and bitsandbytes.
+
+
+Interested in adding a new quantization method to Transformers? Read the [HfQuantizer](./hf_quantizer) guide to learn how!
+
+
+
## AWQ
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 39cb0eabe6..d6bd14fccd 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1001,6 +1001,7 @@ _import_structure = {
"pipeline",
],
"processing_utils": ["ProcessorMixin"],
+ "quantizers": [],
"testing_utils": [],
"tokenization_utils": ["PreTrainedTokenizer"],
"tokenization_utils_base": [
diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py
index 8ef9a7ec96..dd8578ef60 100644
--- a/src/transformers/integrations/awq.py
+++ b/src/transformers/integrations/awq.py
@@ -187,17 +187,18 @@ def fuse_awq_modules(model, quantization_config):
Args:
model (`~PreTrainedModel`):
The model to fuse - note this model should have been converted into AWQ format beforehand.
- quantization_config (`dict`):
+ quantization_config (`Union[AwqConfig, dict]`):
The quantization configuration to use.
"""
# We need to convert it from dict in order to get an AwqConfig object
# otherwise the fields `backend` etc. will not be available
# https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
- awq_config = AwqConfig.from_dict(quantization_config)
- backend = awq_config.backend
+ if isinstance(quantization_config, dict):
+ quantization_config = AwqConfig.from_dict(quantization_config)
+ backend = quantization_config.backend
- modules_to_fuse = get_modules_to_fuse(model, awq_config)
- modules_to_not_convert = getattr(awq_config, "modules_to_not_convert", None)
+ modules_to_fuse = get_modules_to_fuse(model, quantization_config)
+ modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)
if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.fused.attn import QuantAttentionFused
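Reviewer sketch (not part of the patch): the change to `fuse_awq_modules` above accepts either a plain `dict` or an already-built config object and normalizes to the object form before reading fields such as `backend`. The toy below shows that pattern in isolation. `ToyAwqConfig` and `as_config` are hypothetical stand-ins written for this note, and the `from_dict` here (which drops unknown keys) is an assumption, not the behavior of the real `AwqConfig.from_dict`.

```python
from dataclasses import dataclass, fields
from typing import Optional, Union


@dataclass
class ToyAwqConfig:
    """Hypothetical stand-in for `AwqConfig`, holding only two fields."""

    backend: str = "autoawq"
    modules_to_not_convert: Optional[list] = None

    @classmethod
    def from_dict(cls, config_dict):
        # Assumption for this sketch: silently ignore keys the dataclass does not know.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config_dict.items() if k in known})


def as_config(quantization_config: Union[ToyAwqConfig, dict]) -> ToyAwqConfig:
    # Convert only when needed, so a config object passed by the caller is reused as-is.
    if isinstance(quantization_config, dict):
        return ToyAwqConfig.from_dict(quantization_config)
    return quantization_config


from_dict_cfg = as_config({"backend": "llm-awq", "unrelated_key": 1})
existing_cfg = ToyAwqConfig(modules_to_not_convert=["lm_head"])
same_cfg = as_config(existing_cfg)
```

Normalizing once at the top of the function keeps the rest of the body free of `dict` versus attribute access branches.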
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b899921ad0..15855ceb7e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -53,6 +53,7 @@ from .pytorch_utils import ( # noqa: F401
prune_layer,
prune_linear_layer,
)
+from .quantizers import AutoHfQuantizer, HfQuantizer
from .safetensors_conversion import auto_conversion
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
@@ -75,8 +76,6 @@ from .utils import (
extract_commit_hash,
has_file,
is_accelerate_available,
- is_auto_awq_available,
- is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
is_offline_mode,
@@ -97,7 +96,7 @@ from .utils.import_utils import (
is_torch_fx_proxy,
is_torchdynamo_compiling,
)
-from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod
+from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
@@ -693,7 +692,7 @@ def _load_state_dict_into_meta_model(
state_dict_folder=None,
state_dict_index=None,
dtype=None,
- is_quantized=False,
+ hf_quantizer=None,
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
@@ -715,9 +714,6 @@ def _load_state_dict_into_meta_model(
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
- if is_quantized:
- from .integrations import set_module_quantized_tensor_to_device
-
error_msgs = []
old_keys = []
@@ -799,44 +795,17 @@ def _load_state_dict_into_meta_model(
if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
- state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
- elif not is_quantized:
- # For backward compatibility with older versions of `accelerate`
+ state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+ elif (
+ hf_quantizer is None
+ or (not hf_quantizer.requires_parameters_quantization)
+ or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
+ ):
+ # For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
- elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
- # handling newly quantized weights and loaded quantized weights
- # edit the param.dtype restrictions and is_quantized condition when adding new quant methods
- quantized_stats = {}
-
- if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
- param_name + ".quant_state.bitsandbytes__nf4" in state_dict
- ):
- # 4bit loading. Collecting components for restoring quantized weight
- # This can be expanded to make a universal call for any quantized weight loading
- for k, v in state_dict.items():
- if param_name + "." in k:
- quantized_stats[k] = v
- unexpected_keys.remove(k)
-
- set_module_quantized_tensor_to_device(
- model, param_name, param_device, value=param, quantized_stats=quantized_stats
- )
-
- elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
- # 8bit loading. Could be combined with the above 4bit call.
- # condition looks unreliable
- fp16_statistics_key = param_name.replace("weight", "SCB")
- unexpected_keys.remove(fp16_statistics_key)
- set_module_quantized_tensor_to_device(
- model,
- param_name,
- param_device,
- value=param,
- quantized_stats={"SCB": state_dict[fp16_statistics_key]},
- )
else:
- # loading not quantized params in quantized model
- set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)
+ hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+ # TODO: consider removing used param_parts from state_dict before return
return error_msgs, offload_index, state_dict_index
@@ -1086,6 +1055,7 @@ class ModuleUtilsMixin:
total_numel = []
is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
if is_loaded_in_4bit:
if is_bitsandbytes_available():
import bitsandbytes as bnb
@@ -2283,30 +2253,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
- # Checks if the model has been loaded in 8-bit
- if (
- getattr(self, "is_loaded_in_8bit", False)
- and not getattr(self, "is_8bit_serializable", False)
- and not _hf_peft_config_loaded
- ):
- raise NotImplementedError(
- "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
- "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
- )
+ hf_quantizer = getattr(self, "hf_quantizer", None)
+ quantization_serializable = (
+ hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable
+ )
- if (
- getattr(self, "is_loaded_in_4bit", False)
- and not getattr(self, "is_4bit_serializable", False)
- and not _hf_peft_config_loaded
- ):
- raise NotImplementedError(
- "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
- "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
+ if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+ raise ValueError(
+ f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+ " the logger on the traceback to understand the reason why the quantized model is not serializable."
)
- if getattr(self, "_awq_is_fused", False):
- raise ValueError("You cannot save an AWQ model that uses fused modules!")
-
if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
@@ -2788,15 +2745,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
- load_in_8bit (`bool`, *optional*, defaults to `False`):
- If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please
- install `bitsandbytes` (`pip install -U bitsandbytes`).
- load_in_4bit (`bool`, *optional*, defaults to `False`):
- If `True`, will convert the loaded model into 4bit precision quantized model. To use this feature
- install the latest version of `bitsandbytes` (`pip install -U bitsandbytes`).
quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
- bitsandbytes, gptq)
+ bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
+ `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes
+ quantizations and not preferred. Consider inserting all such arguments into quantization_config
+ instead.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
@@ -2910,14 +2864,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if use_safetensors is None and not is_safetensors_available():
use_safetensors = False
-
- if is_bitsandbytes_available():
- is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
- is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
- else:
- is_4bit_serializable = False
- is_8bit_serializable = False
-
if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
@@ -3002,69 +2948,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
- quantization_method_from_args = None
-
- if quantization_config is not None:
- quantization_method_from_args = getattr(
- quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES
- )
-
- if quantization_config is None and (load_in_8bit or load_in_4bit):
- quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES
- quantization_config, kwargs = BitsAndBytesConfig.from_dict(
- config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit": load_in_4bit},
- return_unused_kwargs=True,
- **kwargs,
- )
- elif quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES:
- load_in_8bit = quantization_config.load_in_8bit
- load_in_4bit = quantization_config.load_in_4bit
-
- quantization_config_kwargs = {
- k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters
- }
-
- if len(quantization_config_kwargs) > 0:
+ # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
+ if load_in_4bit or load_in_8bit:
+ if quantization_config is not None:
raise ValueError(
- "You can't pass `load_in_8bit` or any other `BitsAndBytesConfig` argument as a kwarg when passing "
+ "You can't pass `load_in_4bit` or `load_in_8bit` as a kwarg when passing "
"`quantization_config` argument at the same time."
)
- if load_in_8bit or load_in_4bit:
- if not torch.cuda.is_available():
- raise RuntimeError("No GPU found. A GPU is needed for quantization.")
- if not (is_accelerate_available() and is_bitsandbytes_available()):
- raise ImportError(
- "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
- " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
- " `pip install bitsandbytes`."
- )
+ # preparing BitsAndBytesConfig from kwargs
+ config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
+ config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
+ quantization_config, kwargs = BitsAndBytesConfig.from_dict(
+ config_dict=config_dict, return_unused_kwargs=True, **kwargs
+ )
+ logger.warning(
+ "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in future versions. "
+ "Please pass a `BitsAndBytesConfig` object in the `quantization_config` argument instead."
+ )
- if torch_dtype is None:
- # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
- logger.info(
- f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
- "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
- "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
- " torch_dtype=torch.float16 to remove this warning."
- )
- torch_dtype = torch.float16
-
- if device_map is None:
- device_map = {"": torch.cuda.current_device()}
- logger.info(
- "The device_map was not initialized. "
- "Setting device_map to {'':torch.cuda.current_device()}. "
- "If you want to use the model for inference, please set device_map ='auto' "
- )
- if low_cpu_mem_usage is None:
- low_cpu_mem_usage = True
-
- if from_tf or from_flax:
- raise ValueError(
- "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
- " sure the weights are in PyTorch format."
- )
+ from_pt = not (from_tf | from_flax)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
@@ -3107,155 +3010,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
- quantizer = None
- quantization_method_from_config = None
- if hasattr(config, "quantization_config"):
- quantization_method_from_config = config.quantization_config.get(
- "quant_method", QuantizationMethod.BITS_AND_BYTES
- )
-
- if (
- quantization_method_from_args is not None
- and quantization_method_from_args == QuantizationMethod.AWQ
- and quantization_method_from_config is None
- ):
- raise ValueError(
- "You cannot quantize with AWQ a non-quantized model using transformers, please refer to the quantization documentation"
- " to read more about how to quantize models with AWQ algorithm https://huggingface.co/docs/transformers/main_classes/quantization"
- )
-
- if quantization_method_from_config is not None and quantization_method_from_args is not None:
- if quantization_method_from_config != quantization_method_from_args:
- raise ValueError(
- f"The model is already quantized with {quantization_method_from_config}. "
- f"You can't quantize it again with {quantization_method_from_args}"
- )
-
- if (
- quantization_method_from_config in (QuantizationMethod.GPTQ, QuantizationMethod.AWQ)
- and quantization_method_from_args is not None
- ):
- loading_attr_dict = quantization_config.get_loading_attributes()
- for attr, val in loading_attr_dict.items():
- config.quantization_config[attr] = val
- quantization_method_from_args = None
- logger.warning(
- f"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
- f"`quantization_config` attribute and has already quantized weights. However, loading attributes"
- f" (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
- )
- if (
- quantization_method_from_args == QuantizationMethod.GPTQ
- or quantization_method_from_config == QuantizationMethod.GPTQ
- ):
- gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
- if not gptq_supports_cpu and not torch.cuda.is_available():
- raise RuntimeError("GPU is required to quantize or run quantize model.")
- elif not (is_optimum_available() and is_auto_gptq_available()):
- raise ImportError(
- "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)"
- )
- elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
- raise ImportError(
- "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`"
+ pre_quantized = getattr(config, "quantization_config", None) is not None
+ if pre_quantized or quantization_config is not None:
+ if pre_quantized:
+ config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
+ config.quantization_config, quantization_config
)
else:
- # Need to protect the import
- from optimum.gptq import GPTQQuantizer
- if quantization_method_from_config == QuantizationMethod.GPTQ:
- quantization_config = GPTQConfig.from_dict(config.quantization_config)
config.quantization_config = quantization_config
- if torch_dtype is None:
- torch_dtype = torch.float16
- else:
- logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
- quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict_optimum())
- elif quantization_method_from_config == QuantizationMethod.AWQ:
- if not torch.cuda.is_available():
- raise RuntimeError("GPU is required to run AWQ quantized model.")
+ hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+ else:
+ hf_quantizer = None
- if not is_auto_awq_available():
- raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")
-
- if not is_accelerate_available():
- raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
-
- if device_map is None:
- logger.warning(
- "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
- "your model on a GPU device in order to run your model."
- )
- elif device_map is not None:
- if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
- raise ValueError(
- "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
- " This is not supported. Please remove the CPU or disk device from the device_map."
- )
-
- if torch_dtype is None:
- torch_dtype = torch.float16
- else:
- logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(
+ torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
+ )
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+ device_map = hf_quantizer.update_device_map(device_map)
# Force-set to `True` for more mem efficiency
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
-
- if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
- (is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
- ):
- if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
- logger.warning(
- "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
- " `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the"
- " one you passed to `from_pretrained`."
- )
- config.quantization_config = quantization_config
- elif (
- (is_8bit_serializable or is_4bit_serializable)
- and not (load_in_8bit or load_in_4bit)
- and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
- ):
- quantization_config = config.quantization_config
- if isinstance(quantization_config, dict):
- quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False)
- elif isinstance(quantization_config, BitsAndBytesConfig):
- pass
- else:
- raise ValueError(
- f"Invalid type for `quantization_config`: {type(quantization_config)}. Should be a `dict` or a"
- " `BitsAndBytesConfig` instance."
- )
-
- load_in_8bit = quantization_config.load_in_8bit
- load_in_4bit = quantization_config.load_in_4bit
-
- if load_in_8bit or load_in_4bit:
- if torch_dtype is None:
- torch_dtype = torch.float16
- if device_map is None:
- if torch.cuda.is_available():
- device_map = {"": torch.cuda.current_device()}
- else:
- raise RuntimeError("No GPU found. A GPU is needed for quantization.")
- logger.info(
- "The device_map was not initialized. "
- "Setting device_map to {'':torch.cuda.current_device()}. "
- "If you want to use the model for inference, please set device_map ='auto' "
- )
- if low_cpu_mem_usage is None:
- low_cpu_mem_usage = True
-
- elif (
- not is_8bit_serializable
- and not (load_in_8bit or load_in_4bit)
- and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
- ):
- logger.warning(
- "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
- " `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
- " `pip install --upgrade bitsandbytes`."
- )
+ logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
@@ -3564,7 +3341,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
- torch_dtype == torch.float16 or load_in_4bit or load_in_8bit
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if is_sharded:
@@ -3587,7 +3364,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
- elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
+ elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
@@ -3610,106 +3387,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
keep_in_fp32_modules = []
- if load_in_8bit or load_in_4bit:
- from .integrations import get_keys_to_not_convert, replace_with_bnb_linear
-
- llm_int8_skip_modules = quantization_config.llm_int8_skip_modules
- load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload
- if load_in_8bit:
- logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
- else:
- logger.info("Detected 4-bit loading: activating 4-bit loading for this model")
-
- # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
- if llm_int8_skip_modules is None:
- modules_to_not_convert = get_keys_to_not_convert(model)
- else:
- modules_to_not_convert = llm_int8_skip_modules
-
- if not isinstance(modules_to_not_convert, list):
- modules_to_not_convert = [modules_to_not_convert]
-
- modules_to_not_convert.extend(keep_in_fp32_modules)
-
- # Extend the modules to not convert to keys that are supposed to be offloaded to `cpu` or `disk`
- if isinstance(device_map, dict) and len(device_map.keys()) > 1:
- keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
-
- if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
- raise ValueError(
- "If you want to offload some keys to `cpu` or `disk`, you need to set "
- "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
- " converted to 8-bit but kept in 32-bit."
- )
-
- modules_to_not_convert.extend(keys_on_cpu)
-
- supports_4bit = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.39.0")
-
- if load_in_4bit and not supports_4bit:
- raise ValueError(
- "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training"
- " make sure you have the latest version of `bitsandbytes` installed"
- )
-
- model = replace_with_bnb_linear(
- model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
)
- # training in 8-bit is only available in 0.37.0+
- model._is_quantized_training_enabled = version.parse(
- importlib.metadata.version("bitsandbytes")
- ) >= version.parse("0.37.0")
-
- config.quantization_config = quantization_config
- model.is_8bit_serializable = is_8bit_serializable
- model.is_4bit_serializable = is_4bit_serializable
-
- if load_in_8bit and torch_dtype is None:
- logger.warning(
- "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute. "
- "All non-linear modules will be loaded in full precision."
- " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
- )
- if quantization_method_from_config == QuantizationMethod.GPTQ:
- model = quantizer.convert_model(model)
- model._is_quantized_training_enabled = True
- elif quantization_method_from_config == QuantizationMethod.AWQ:
- from .integrations import fuse_awq_modules, get_keys_to_not_convert, replace_with_awq_linear
-
- modules_to_not_convert = get_keys_to_not_convert(model)
-
- if quantization_config is None:
- quantization_config = AwqConfig.from_dict(config.quantization_config)
- # In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
- # a `modules_to_not_convert` attribute we need to manually set that attribute into the
- # passed `quantization_config`
- elif (
- getattr(quantization_config, "modules_to_not_convert", None) is None
- and "modules_to_not_convert" in config.quantization_config
- ):
- quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
-
- if quantization_config.modules_to_not_convert is not None:
- modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-
- model, has_been_replaced = replace_with_awq_linear(
- model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
- )
- model._is_quantized_training_enabled = False
-
- if not has_been_replaced:
- logger.warning(
- "You are loading an AWQ model but no linear modules were found in your model."
- " Please double check your model architecture, or submit an issue on github if you think this is"
- " a bug."
- )
-
- if quantization_method_from_config is not None:
- model.quantization_method = quantization_method_from_config
- elif quantization_method_from_args is not None:
- model.quantization_method = quantization_method_from_args
- if hasattr(model, "quantization_method"):
- model.is_quantized = True
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
@@ -3719,14 +3400,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if isinstance(device_map, str):
special_dtypes = {}
- if load_in_8bit or load_in_4bit:
- special_dtypes.update(
- {
- name: torch_dtype
- for name, _ in model.named_parameters()
- if any(m in name for m in modules_to_not_convert)
- }
- )
+
+ if hf_quantizer is not None:
+ special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
special_dtypes.update(
{
@@ -3738,20 +3414,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
target_dtype = torch_dtype
- if load_in_4bit:
- if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
- from accelerate.utils import CustomDtype
-
- target_dtype = CustomDtype.INT4
- else:
- raise ValueError(
- "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute"
- " the appropriate device map, you should upgrade your `accelerate` library, "
- "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map "
- "calculation. You may encounter unexpected behavior, or pass your own device map"
- )
- elif load_in_8bit:
- target_dtype = torch.int8
+ if hf_quantizer is not None:
+ target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
no_split_modules = model._get_no_split_modules(device_map)
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
@@ -3778,32 +3442,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
else:
max_memory = get_max_memory(max_memory)
- if getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
- # need more space for buffers that are created during quantization
- max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ if hf_quantizer is not None:
+ max_memory = hf_quantizer.adjust_max_memory(max_memory)
device_map_kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
- if load_in_8bit or load_in_4bit:
- # The LM head / tied weights or any last module can stay on disk / CPU
- device_map_without_lm_head = {
- key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
- }
- if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
- raise ValueError(
- """
- Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
- the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
- these modules in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom
- `device_map` to `from_pretrained`. Check
- https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
- for more details.
- """
- )
- del device_map_without_lm_head
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(device_map=device_map)
elif device_map is not None:
model.tie_weights()
@@ -3867,13 +3515,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
- is_quantized=(getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES),
+ hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
)
- model.is_loaded_in_4bit = load_in_4bit
- model.is_loaded_in_8bit = load_in_8bit
-
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -3903,14 +3548,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
pass
- if (
- quantization_config is not None
- and quantization_config.quant_method == QuantizationMethod.AWQ
- and quantization_config.do_fuse
- ):
- model = fuse_awq_modules(model, config.quantization_config)
- model._awq_is_fused = True
-
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
device_map_kwargs = {
@@ -3922,16 +3559,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
dispatch_model(model, **device_map_kwargs)
- if quantization_method_from_args == QuantizationMethod.GPTQ:
- if quantization_config.tokenizer is None:
- quantization_config.tokenizer = pretrained_model_name_or_path
- if cls.main_input_name != "input_ids":
- raise RuntimeError("We can only quantize pure text model.")
- quantizer.quantize_model(model, quantization_config.tokenizer)
- config.quantization_config = GPTQConfig.from_dict_optimum(quantizer.to_dict())
- model._is_quantized_training_enabled = True
- if quantization_method_from_config == QuantizationMethod.GPTQ:
- model = quantizer.post_init_model(model)
+ if hf_quantizer is not None:
+ hf_quantizer.postprocess_model(model)
+ model.hf_quantizer = hf_quantizer
if _adapter_model_path is not None:
model.load_adapter(
@@ -3969,12 +3599,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder=None,
offload_state_dict=None,
dtype=None,
- is_quantized=False,
+ hf_quantizer=None,
keep_in_fp32_modules=None,
):
is_safetensors = False
- if is_quantized:
- from .integrations import set_module_quantized_tensor_to_device
if device_map is not None and "disk" in device_map.values():
archive_file = (
@@ -4098,12 +3726,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
target_dtype = torch.float32
if param.device == torch.device("meta"):
- if not (is_quantized):
- set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
+ value = torch.empty(*param.size(), dtype=target_dtype)
+ if (
+ hf_quantizer is None
+ or getattr(hf_quantizer, "requires_parameters_quantization", False)
+ or not hf_quantizer.check_quantized_param(model, param_value=value, param_name=key, state_dict={})
+ ):
+ set_module_tensor_to_device(model, key, "cpu", value)
else:
- set_module_quantized_tensor_to_device(
- model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
- )
+ hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict)
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@@ -4278,14 +3909,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
- if not (is_quantized):
+ if hf_quantizer is None:
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
- set_module_quantized_tensor_to_device(
- model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
- )
+ hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
@@ -4299,7 +3928,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
- is_quantized=is_quantized,
+ hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
@@ -4407,7 +4036,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return retrieved_modules
@staticmethod
- def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
+ def _load_pretrained_model_low_mem(
+ model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
+ ):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -4422,12 +4053,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
- Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
+ Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
+ handle bitsandbytes, a non-empty hf_quantizer argument is needed.
"""
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
- error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
+ expected_keys = loaded_state_dict_keys # placeholder for missing expected_keys. TODO: replace with proper keys
+ error_msgs = _load_state_dict_into_meta_model(
+ model,
+ state_dict,
+ loaded_state_dict_keys,
+ start_prefix,
+ expected_keys=expected_keys,
+ hf_quantizer=hf_quantizer,
+ )
return error_msgs
@classmethod
diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py
new file mode 100644
index 0000000000..3409af4cd7
--- /dev/null
+++ b/src/transformers/quantizers/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .base import HfQuantizer
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
new file mode 100644
index 0000000000..549c4fe132
--- /dev/null
+++ b/src/transformers/quantizers/auto.py
@@ -0,0 +1,148 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Dict, Optional, Union
+
+from ..models.auto.configuration_auto import AutoConfig
+from ..utils.quantization_config import (
+ AwqConfig,
+ BitsAndBytesConfig,
+ GPTQConfig,
+ QuantizationConfigMixin,
+ QuantizationMethod,
+)
+from .quantizer_awq import AwqQuantizer
+from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
+from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
+from .quantizer_gptq import GptqHfQuantizer
+
+
+AUTO_QUANTIZER_MAPPING = {
+ "awq": AwqQuantizer,
+ "bitsandbytes_4bit": Bnb4BitHfQuantizer,
+ "bitsandbytes_8bit": Bnb8BitHfQuantizer,
+ "gptq": GptqHfQuantizer,
+}
+
+AUTO_QUANTIZATION_CONFIG_MAPPING = {
+ "awq": AwqConfig,
+ "bitsandbytes_4bit": BitsAndBytesConfig,
+ "bitsandbytes_8bit": BitsAndBytesConfig,
+ "gptq": GPTQConfig,
+}
+
+
+class AutoQuantizationConfig:
+ """
+ The Auto-HF quantization config class that takes care of automatically dispatching to the correct
+ quantization config given a quantization config stored in a dictionary.
+ """
+
+ @classmethod
+ def from_dict(cls, quantization_config_dict: Dict):
+ quant_method = quantization_config_dict.get("quant_method", None)
+ # bnb models need special care to make sure everything stays backward compatible (BC)
+ if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+ suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+ quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+ elif quant_method is None:
+ raise ValueError(
+ "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
+ )
+
+ if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
+ return target_cls.from_dict(quantization_config_dict)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ if getattr(model_config, "quantization_config", None) is None:
+ raise ValueError(
+ f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+ )
+ quantization_config_dict = model_config.quantization_config
+ quantization_config = cls.from_dict(quantization_config_dict)
+ # Update with potential kwargs that are passed through from_pretrained.
+ quantization_config.update(kwargs)
+ return quantization_config
+
+
+class AutoHfQuantizer:
+ """
+ The Auto-HF quantizer class that takes care of automatically instantiating the correct
+ `HfQuantizer` given the `QuantizationConfig`.
+ """
+
+ @classmethod
+ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
+ # Convert it to a QuantizationConfig if the q_config is a dict
+ if isinstance(quantization_config, dict):
+ quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
+
+ quant_method = quantization_config.quant_method
+
+ # Again, we need special care for bnb as we have a single quantization config
+ # class for both 4-bit and 8-bit quantization
+ if quant_method == QuantizationMethod.BITS_AND_BYTES:
+ if quantization_config.load_in_8bit:
+ quant_method += "_8bit"
+ else:
+ quant_method += "_4bit"
+
+ if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
+ return target_cls(quantization_config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ return cls.from_config(quantization_config)
+
+ @classmethod
+ def merge_quantization_configs(
+ cls,
+ quantization_config: Union[dict, QuantizationConfigMixin],
+ quantization_config_from_args: Optional[QuantizationConfigMixin],
+ ):
+ """
+ handles situations where both quantization_config from args and quantization_config from model config are present.
+ """
+ warning_msg = (
+ "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
+ " already has a `quantization_config` attribute. The `quantization_config` from the model will prevail."
+ )
+
+ if isinstance(quantization_config, dict):
+ quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
+
+ if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
+ # special case for GPTQ / AWQ config collision
+ loading_attr_dict = quantization_config_from_args.get_loading_attributes()
+ for attr, val in loading_attr_dict.items():
+ setattr(quantization_config, attr, val)
+ warning_msg += f" However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the ones you passed to `from_pretrained`. The rest will be ignored."
+
+ warnings.warn(warning_msg)
+ return quantization_config
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
new file mode 100644
index 0000000000..c8eb8bacaa
--- /dev/null
+++ b/src/transformers/quantizers/base.py
@@ -0,0 +1,220 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from ..utils import is_torch_available
+from ..utils.import_utils import _is_package_available
+from ..utils.quantization_config import QuantizationConfigMixin
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+if is_torch_available():
+ import torch
+
+
+class HfQuantizer(ABC):
+ """
+ Abstract class of the HuggingFace quantizer. For now, it supports quantizing HF transformers models for inference and/or quantization.
+ This class is used only for transformers.PreTrainedModel.from_pretrained and cannot be easily used outside the scope of that method
+ yet.
+
+ Attributes
+ quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
+ The quantization config that defines the quantization parameters of your model that you want to quantize.
+ modules_to_not_convert (`List[str]`, *optional*):
+ The list of module names to not convert when quantizing the model.
+ required_packages (`List[str]`, *optional*):
+ The list of required pip packages to install prior to using the quantizer.
+ requires_calibration (`bool`):
+ Whether the quantization method requires calibrating the model before using it.
+ requires_parameters_quantization (`bool`):
+ Whether the quantization method requires creating a new Parameter. For example, for bitsandbytes, it is
+ required to create a new xxxParameter in order to properly quantize the model.
+ """
+
+ requires_calibration = False
+ required_packages = None
+ requires_parameters_quantization = False
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ self.quantization_config = quantization_config
+
+ # -- Handle extra kwargs below --
+ self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
+ self.pre_quantized = kwargs.pop("pre_quantized", True)
+
+ if not self.pre_quantized and self.requires_calibration:
+ raise ValueError(
+ f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
+ f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
+ f"pass `pre_quantized=True` while knowing what you are doing."
+ )
+
+ self.check_packages_compatibility()
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Some quantization methods require explicitly setting the dtype of the model to a
+ target dtype. Override this method if you want to make sure that behavior is
+ preserved.
+
+ Args:
+ torch_dtype (`torch.dtype`):
+ The input dtype that is passed in `from_pretrained`
+ """
+ return torch_dtype
+
+ def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ """
+ Override this method if you want to override the existing device map with a new
+ one. E.g. for bitsandbytes, since `accelerate` is a hard requirement, if no device_map is
+ passed, the device_map is set to `"auto"`.
+
+ Args:
+ device_map (`Union[dict, str]`, *optional*):
+ The device_map that is passed through the `from_pretrained` method.
+ """
+ return device_map
+
+ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained`
+ to compute the device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype`
+ to `torch.int8` and for 4-bit we pass a custom enum `accelerate.utils.CustomDtype.INT4`.
+
+ Args:
+ torch_dtype (`torch.dtype`, *optional*):
+ The torch_dtype that is used to compute the device_map.
+ """
+ return torch_dtype
+
+ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
+ """
+ Returns the dtypes of modules that are not quantized. Used to compute the device_map when
+ a str is passed as the device_map. The method uses the `modules_to_not_convert` that is modified
+ in `_process_model_before_weight_loading`.
+
+ Args:
+ model (`~transformers.PreTrainedModel`):
+ The model to quantize
+ torch_dtype (`torch.dtype`):
+ The dtype passed in `from_pretrained` method.
+ """
+ return {
+ name: torch_dtype
+ for name, _ in model.named_parameters()
+ if any(m in name for m in self.modules_to_not_convert)
+ }
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
+ return max_memory
+
+ def check_quantized_param(
+ self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+ ) -> bool:
+ """
+ Checks whether a loaded state_dict component is part of a quantized param, with some validation; only defined
+ for quantization methods that need to create new parameters for quantization, i.e. when
+ requires_parameters_quantization == True.
+ """
+ return False
+
+ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
+ """
+ takes needed components from state_dict and creates quantized param; only applicable if
+ requires_parameters_quantization == True
+ """
+ if not self.requires_parameters_quantization:
+ raise AttributeError(
+ f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
+ )
+
+ def validate_environment(self, *args, **kwargs):
+ """
+ This method is used to check for potential conflicts with arguments that are
+ passed in `from_pretrained`. You need to define it for all future quantizers that are integrated with transformers.
+ If no explicit checks are needed, simply return nothing.
+ """
+ return
+
+ def check_packages_compatibility(self):
+ """
+ Check the compatibility of the quantizer with respect to the current environment. Loops over all package
+ names in `self.required_packages` and checks if each package is available.
+ """
+ if self.required_packages is not None:
+ non_available_packages = []
+ for package_name in self.required_packages:
+ is_package_available = _is_package_available(package_name)
+ if not is_package_available:
+ non_available_packages.append(package_name)
+
+ if len(non_available_packages) > 0:
+ raise ValueError(
+ f"The packages {self.required_packages} are required to use {self.__class__.__name__}"
+ f" the following packages are missing in your environment: {non_available_packages}, please make sure"
+ f" to install them in order to use the quantizer."
+ )
+
+ def preprocess_model(self, model: "PreTrainedModel", **kwargs):
+ """
+ Setting model attributes and/or converting model before weights loading. At this point
+ the model should be initialized on the meta device so you can freely manipulate the skeleton
+ of the model in order to replace modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
+
+ Args:
+ model (`~transformers.PreTrainedModel`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along to `_process_model_before_weight_loading`.
+ """
+ model.is_quantized = True
+ model.quantization_method = self.quantization_config.quant_method
+ return self._process_model_before_weight_loading(model, **kwargs)
+
+ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
+ """
+ Post-process the model after weight loading.
+ Make sure to override the abstract method `_process_model_after_weight_loading`.
+
+ Args:
+ model (`~transformers.PreTrainedModel`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along to `_process_model_after_weight_loading`.
+ """
+ model._is_quantized_training_enabled = self.is_trainable
+ return self._process_model_after_weight_loading(model, **kwargs)
+
+ @abstractmethod
+ def _process_model_before_weight_loading(self, model, **kwargs):
+ ...
+
+ @abstractmethod
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ ...
+
+ @property
+ @abstractmethod
+ def is_serializable(self):
+ ...
+
+ @property
+ @abstractmethod
+ def is_trainable(self):
+ ...
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
new file mode 100644
index 0000000000..3e10730994
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -0,0 +1,110 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class AwqQuantizer(HfQuantizer):
+ """
+ 4-bit quantization for Activation-aware Weight Quantization (AWQ) (https://arxiv.org/abs/2306.00978)
+ """
+
+ # AWQ requires data calibration - we support only inference
+ requires_calibration = True
+
+ required_packages = ["awq", "accelerate"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ def validate_environment(self, device_map, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("GPU is required to run AWQ quantized model.")
+
+ if not is_auto_awq_available():
+ raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")
+
+ if not is_accelerate_available():
+ raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
+
+ if device_map is None:
+ logger.warning_once(
+ "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
+ "your model on a GPU device in order to run your model."
+ )
+ elif device_map is not None:
+ if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
+ raise ValueError(
+ "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
+ " This is not supported. Please remove the CPU or disk device from the device_map."
+ )
+
+ def update_torch_dtype(self, torch_dtype):
+ if torch_dtype is None:
+ torch_dtype = torch.float16
+ elif torch_dtype != torch.float16:
+ logger.warning("We suggest you set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+ return torch_dtype
+
+ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
+
+ self.modules_to_not_convert = get_keys_to_not_convert(model)
+
+ if self.quantization_config.modules_to_not_convert is not None:
+ self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+
+ model, has_been_replaced = replace_with_awq_linear(
+ model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
+ )
+
+ if not has_been_replaced:
+ logger.warning(
+ "You are loading an AWQ model but no linear modules were found in your model."
+ " Please double check your model architecture, or submit an issue on github if you think this is a bug."
+ )
+
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ if self.quantization_config.do_fuse:
+ from ..integrations import fuse_awq_modules
+
+ model = fuse_awq_modules(model, self.quantization_config)
+ model._awq_is_fused = True # TODO: consider storing this flag in model.config instead
+
+ @property
+ def is_serializable(self):
+ # AWQ through auto-awq has always been serializable, except if the model is fused.
+ if self.quantization_config.do_fuse:
+ logger.warning("You cannot save an AWQ model that uses fused modules!")
+ return False
+ return True
+
+ @property
+ def is_trainable(self):
+ # AWQ supports neither QAT (Quantization-Aware Training) nor PEFT yet.
+ # TODO: if this is supported in the future, do a version check here.
+ return False
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
new file mode 100644
index 0000000000..7cc9ef6560
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -0,0 +1,312 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from packaging import version
+
+from .base import HfQuantizer
+from .quantizers_utils import get_module_from_name
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+ from ..pytorch_utils import Conv1D
+
+logger = logging.get_logger(__name__)
+
+
+class Bnb4BitHfQuantizer(HfQuantizer):
+ """
+ 4-bit quantization from the bitsandbytes quantization method:
+ before loading: converts transformer layers into Linear4bit; during loading: loads the 16-bit weights and
+ passes them to the layer object; after loading: quantizes individual weights in Linear4bit into 4-bit at the first .cuda() call
+ saving:
+ from state dict, as usual; saves weights and `quant_state` components
+ loading:
+ need to locate `quant_state` components and pass them to the Params4bit constructor
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_parameters_quantization = True
+ requires_calibration = False
+
+ required_packages = ["bitsandbytes", "accelerate"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ def validate_environment(self, *args, **kwargs):
+ if not (is_accelerate_available() and is_bitsandbytes_available()):
+ raise ImportError(
+ "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install accelerate` "
+ "and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
+ )
+
+ if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_lm_head = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+ raise ValueError(
+ """
+ Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
+ quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
+ in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
+ `from_pretrained`. Check
+ https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
+ for more details.
+ """
+ )
+
+ if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"):
+ raise ValueError(
+ "You have a version of `bitsandbytes` that is not compatible with 4-bit inference and training."
+ " Make sure you have the latest version of `bitsandbytes` installed."
+ )
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
+ from accelerate.utils import CustomDtype
+
+ if target_dtype != torch.int8:
+ logger.info(f"target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
+ return CustomDtype.INT4
+ else:
+ raise ValueError(
+ "You are using `device_map='auto'` on a 4-bit loaded version of the model. To automatically compute"
+ " the appropriate device map, you should upgrade your `accelerate` library,"
+ " `pip install --upgrade accelerate` or install it from source to support fp4 auto device map"
+ " calculation. Alternatively, pass your own device map."
+ )
+
+ def check_quantized_param(
+ self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+ ) -> bool:
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+ # Add here check for loaded components' dtypes once serialization is implemented
+ return True
+ elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
+ # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
+ # but it would wrongly use uninitialized weight there.
+ return True
+ else:
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: List[str],
+ ):
+ """
+ combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
+ """
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if tensor_name == "bias":
+ if param_value is None:
+ new_value = old_value.to(target_device)
+ else:
+ new_value = param_value.to(target_device)
+
+ new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
+ module._parameters[tensor_name] = new_value
+ return
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+ raise ValueError("this function only loads `Linear4bit components`")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {target_device}.")
+
+ # construct `new_value` for the module._parameters[tensor_name]:
+ if self.pre_quantized:
+ # 4bit loading. Collecting components for restoring quantized weight
+ # This can be expanded to make a universal call for any quantized weight loading
+
+ if not self.is_serializable:
+ raise ValueError(
+ "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
+ param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
+ ):
+ raise ValueError(
+ f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
+ )
+
+ quantized_stats = {}
+ for k, v in state_dict.items():
+ if param_name + "." in k:
+ quantized_stats[k] = v
+ unexpected_keys.remove(k)
+
+ new_value = bnb.nn.Params4bit.from_prequantized(
+ data=param_value,
+ quantized_stats=quantized_stats,
+ requires_grad=False,
+ device=target_device,
+ )
+ else:
+ new_value = param_value.to("cpu")
+
+ # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
+ # Since weights are saved in the correct "orientation", we skip transposing when loading.
+ if issubclass(module.source_cls, Conv1D):
+ new_value = new_value.T
+
+ kwargs = old_value.__dict__
+ new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+
+ # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_torch_dtype
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map
+ def update_device_map(self, device_map):
+ if device_map is None:
+ device_map = {"": torch.cuda.current_device()}
+ logger.info(
+ "The device_map was not initialized. "
+ "Setting device_map to {'':torch.cuda.current_device()}. "
+ "If you want to use the model for inference, please set device_map ='auto' "
+ )
+ return device_map
+
+ # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
+ def _process_model_before_weight_loading(
+ self,
+ model: "PreTrainedModel",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+ if self.quantization_config.llm_int8_skip_modules is None:
+ self.modules_to_not_convert = get_keys_to_not_convert(model)
+ else:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ "converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbytes.py to here
+
+ model.config.quantization_config = self.quantization_config
+
+ # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ model._is_quantized_training_enabled = self.is_trainable
+ model.is_loaded_in_4bit = True
+ model.is_4bit_serializable = self.is_serializable
+ return model
+
+ @property
+ def is_serializable(self):
+ _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
+
+ if not _is_4bit_serializable:
+ logger.warning(
+ "You are calling `save_pretrained` on a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
+ "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
+ )
+ return False
+
+ return True
+
+ @property
+ def is_trainable(self) -> bool:
+ return True
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
new file mode 100644
index 0000000000..6428b13c25
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -0,0 +1,272 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from packaging import version
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from .quantizers_utils import get_module_from_name
+
+
+if is_torch_available():
+ import torch
+
+ from ..pytorch_utils import Conv1D
+
+logger = logging.get_logger(__name__)
+
+
+class Bnb8BitHfQuantizer(HfQuantizer):
+ """
+ 8-bit quantization from the bitsandbytes quantization method:
+ before loading: converts transformer layers into Linear8bitLt; during loading: loads the 16-bit weights and
+ passes them to the layer object; after loading: quantizes individual weights in Linear8bitLt into 8-bit at the first .cuda() call
+ saving:
+ from state dict, as usual; saves weights and 'SCB' component
+ loading:
+ need to locate the SCB component and pass it to the Linear8bitLt object
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_parameters_quantization = True
+ requires_calibration = False
+
+ required_packages = ["bitsandbytes", "accelerate"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ def validate_environment(self, *args, **kwargs):
+ if not (is_accelerate_available() and is_bitsandbytes_available()):
+ raise ImportError(
+ "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
+ "and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
+ )
+
+ if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_lm_head = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+ raise ValueError(
+ """
+ Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
+ quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
+ in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
+ `from_pretrained`. Check
+ https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
+ for more details.
+ """
+ )
+
+ if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"):
+ raise ValueError(
+ "You have a version of `bitsandbytes` that is not compatible with 8-bit inference and training."
+ " Make sure you have the latest version of `bitsandbytes` installed."
+ )
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ def update_device_map(self, device_map):
+ if device_map is None:
+ device_map = {"": torch.cuda.current_device()}
+ logger.info(
+ "The device_map was not initialized. "
+ "Setting device_map to {'':torch.cuda.current_device()}. "
+ "If you want to use the model for inference, please set device_map ='auto' "
+ )
+ return device_map
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.int8:
+ logger.info(f"target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
+ return torch.int8
+
+ def check_quantized_param(
+ self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+ ):
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+ if self.pre_quantized:
+ if param_name.replace("weight", "SCB") not in state_dict.keys():
+ raise ValueError("Missing quantization component `SCB`")
+ if param_value.dtype != torch.int8:
+ raise ValueError(
+ f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
+ )
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: List[str],
+ ):
+ """
+ combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
+ needs aux items from state dicts, if found - removes them from unexpected_keys
+ """
+ import bitsandbytes as bnb
+
+ fp16_statistics_key = param_name.replace("weight", "SCB")
+ fp16_statistics = state_dict.get(fp16_statistics_key, None)
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+ raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {target_device}.")
+
+ new_value = param_value.to("cpu")
+ if self.pre_quantized and not self.is_serializable:
+ raise ValueError(
+ "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
+ # Since weights are saved in the correct "orientation", we skip transposing when loading.
+ if issubclass(module.source_cls, Conv1D):
+ if fp16_statistics is None:
+ new_value = new_value.T
+
+ kwargs = old_value.__dict__
+ new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+ if fp16_statistics is not None:
+ setattr(module.weight, "SCB", fp16_statistics.to(target_device))
+ unexpected_keys.remove(fp16_statistics_key)
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ model._is_quantized_training_enabled = self.is_trainable
+ model.is_loaded_in_8bit = True
+ model.is_8bit_serializable = self.is_serializable
+ return model
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "PreTrainedModel",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+ if self.quantization_config.llm_int8_skip_modules is None:
+ self.modules_to_not_convert = get_keys_to_not_convert(model)
+ else:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ "converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbytes.py to here
+
+ model.config.quantization_config = self.quantization_config
+
+ @property
+ def is_serializable(self):
+ _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
+ "0.37.2"
+ )
+
+ if not _bnb_supports_8bit_serialization:
+ logger.warning(
+ "You are calling `save_pretrained` on an 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
+ "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or"
+ " unexpected behaviours."
+ )
+ return False
+
+ return True
+
+ @property
+ def is_trainable(self) -> bool:
+ return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py
new file mode 100644
index 0000000000..ffc6f2090a
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_gptq.py
@@ -0,0 +1,94 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from typing import TYPE_CHECKING, Optional
+
+from packaging import version
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging
+from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class GptqHfQuantizer(HfQuantizer):
+ """
+ Quantizer for the GPTQ method. For GPTQ, the quantizer supports calibration of the model through the
+ `auto_gptq` package. Quantization is done under the hood for users if they load a non-prequantized model.
+ """
+
+ requires_calibration = False
+ required_packages = ["optimum", "auto_gptq"]
+ optimum_quantizer = None
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+ from optimum.gptq import GPTQQuantizer
+
+ self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum())
+
+ def validate_environment(self, *args, **kwargs):
+ gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+ if not gptq_supports_cpu and not torch.cuda.is_available():
+ raise RuntimeError("A GPU is required to quantize or run a quantized model.")
+ elif not (is_optimum_available() and is_auto_gptq_available()):
+ raise ImportError(
+ "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)"
+ )
+ elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
+ raise ImportError(
+ "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`"
+ )
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ torch_dtype = torch.float16
+ elif torch_dtype != torch.float16:
+ logger.info("We suggest setting `torch_dtype=torch.float16` for better efficiency with GPTQ.")
+ return torch_dtype
+
+ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ if model.__class__.main_input_name != "input_ids":
+ raise RuntimeError("We can only quantize pure text models.")
+
+ if self.pre_quantized:
+ model = self.optimum_quantizer.convert_model(model)
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ if self.pre_quantized:
+ model = self.optimum_quantizer.post_init_model(model)
+ else:
+ if self.quantization_config.tokenizer is None:
+ self.quantization_config.tokenizer = model.name_or_path
+
+ self.optimum_quantizer.quantize_model(model, self.quantization_config.tokenizer)
+ model.config.quantization_config = GPTQConfig.from_dict(self.optimum_quantizer.to_dict())
+
+ @property
+ def is_trainable(self, model: Optional["PreTrainedModel"] = None):
+ return True
+
+ @property
+ def is_serializable(self):
+ return True
diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py
new file mode 100644
index 0000000000..6ae287bf25
--- /dev/null
+++ b/src/transformers/quantizers/quantizers_utils.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Tuple
+
+
+def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
+ if "." in tensor_name:
+ splits = tensor_name.split(".")
+ for split in splits[:-1]:
+ new_module = getattr(module, split)
+ if new_module is None:
+ raise ValueError(f"{module} has no attribute {split}.")
+ module = new_module
+ tensor_name = splits[-1]
+ return module, tensor_name
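
Reviewer note on `get_module_from_name`: the quantizers above call this helper to turn a dotted parameter name into the owning module plus the leaf tensor name. The snippet below is a minimal, standalone sketch of that traversal. The `SimpleNamespace` objects and the name `decoder.fc1.weight` are hypothetical stand-ins for real `nn.Module` instances, chosen only so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Tuple


def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    # Same traversal as the helper in this diff: follow every dotted
    # component except the last, then return the owner and the leaf name.
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]
    return module, tensor_name


# Hypothetical stand-in for a model with one nested linear layer.
linear = SimpleNamespace(weight="W", bias="b")
model = SimpleNamespace(decoder=SimpleNamespace(fc1=linear))

# Resolve a dotted name to (owning object, leaf attribute name).
owner, leaf = get_module_from_name(model, "decoder.fc1.weight")

# A name without dots is returned unchanged, with the root as owner.
root_owner, root_leaf = get_module_from_name(model, "decoder")
```

One detail worth keeping in mind when reviewing: `getattr` without a default raises `AttributeError` for a missing attribute, so the `ValueError` branch is only reached when the attribute exists and is `None`.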
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index f4c91dbf4d..13627ab9de 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -125,6 +125,11 @@ class QuantizationConfigMixin:
"""
return copy.deepcopy(self.__dict__)
+ def __iter__(self):
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+ for attr, value in copy.deepcopy(self.__dict__).items():
+ yield attr, value
+
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@@ -146,6 +151,29 @@ class QuantizationConfigMixin:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+ # Copied from transformers.generation.configuration_utils.GenerationConfig.update
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
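
Reviewer note on the two `QuantizationConfigMixin` additions: the sketch below illustrates the intended behaviour of `__iter__` (so that `dict(obj)` works) and `update` (apply matching kwargs, return the rest). `ToyConfig` and its `bits` and `group_size` attributes are hypothetical and exist only for illustration. The method bodies mirror the ones added in this hunk, but this is not the library class itself.

```python
import copy


class ToyConfig:
    # Hypothetical stand-in for a QuantizationConfigMixin subclass.
    def __init__(self, bits=4, group_size=128):
        self.bits = bits
        self.group_size = group_size

    def __iter__(self):
        # Yielding (name, value) pairs is what lets `dict(obj)` work.
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    def update(self, **kwargs):
        # Set only the attributes that already exist on the instance.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)
        # Return the kwargs that did not match, leaving the input untouched.
        return {key: value for key, value in kwargs.items() if key not in to_remove}


config = ToyConfig()
unused = config.update(bits=8, device_map="auto")
as_dict = dict(config)
```

Here `bits` is updated in place, while `device_map` is not an attribute of `ToyConfig` and therefore comes back in `unused`, which is how a caller can forward leftover kwargs elsewhere.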
diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index be10eb8f91..a2dbd904a5 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -55,8 +55,15 @@ class AwqConfigTest(unittest.TestCase):
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="unexisting-backend")
- # LLMAWQ does not work on a T4
- with self.assertRaises(ValueError):
+ compute_capability = torch.cuda.get_device_capability()
+ major, minor = compute_capability
+
+ if major < 8:
+ # LLMAWQ does not work on a T4
+ with self.assertRaises(ValueError):
+ AwqConfig(bits=4, backend="llm-awq")
+ else:
+ # LLMAWQ should work on an A100
AwqConfig(bits=4, backend="llm-awq")
def test_to_dict(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index da2ce55d31..0ce7274d25 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -95,7 +95,8 @@ class BaseMixedInt8Test(unittest.TestCase):
)
input_text = "Hello my name is"
- EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of the family.\n"
+ EXPECTED_OUTPUTS = set()
+ EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of the family.\n")
MAX_NEW_TOKENS = 10
def setUp(self):
@@ -260,7 +261,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality_config(self):
r"""
@@ -278,7 +279,7 @@ class MixedInt8Test(BaseMixedInt8Test):
input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10
)
- self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_raise_if_config_and_load_in_8bit(self):
r"""
@@ -365,9 +366,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(
- self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
- )
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_int8_serialization_regression(self):
r"""
@@ -392,9 +391,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(
- self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
- )
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_int8_serialization_sharded(self):
r"""
@@ -419,9 +416,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(
- self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
- )
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_int8_from_pretrained(self):
r"""
@@ -441,7 +436,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_bitsandbytes
@@ -628,7 +623,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test):
# Real second forward pass
pipeline_output = self.pipe(self.input_text)
- self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT)
+ self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
@@ -654,7 +649,7 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
# Second real batch
output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+ self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
@@ -671,7 +666,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
# Get the generation
output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True)
- self.assertEqual(output_text, self.EXPECTED_OUTPUT)
+ self.assertIn(output_text, self.EXPECTED_OUTPUTS)
def test_cpu_gpu_loading_random_device_map(self):
r"""
@@ -708,7 +703,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
"transformer.ln_f": 1,
}
- bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
+ bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True, load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name,
@@ -734,7 +729,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
"transformer.h": 0,
"transformer.ln_f": 1,
}
- bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
+ bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True, load_in_8bit=True)
# Load model
model_8bit = AutoModelForCausalLM.from_pretrained(
@@ -760,7 +755,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
"transformer.h": 1,
"transformer.ln_f": "disk",
}
- bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
+ bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True, load_in_8bit=True)
with tempfile.TemporaryDirectory() as tmpdirname:
# Load model
model_8bit = AutoModelForCausalLM.from_pretrained(
@@ -849,7 +844,9 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
class MixedInt8GPT2Test(MixedInt8Test):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
- EXPECTED_OUTPUT = "Hello my name is John Doe, and I'm a big fan of"
+ EXPECTED_OUTPUTS = set()
+ EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I'm a big fan of")
+ EXPECTED_OUTPUTS.add("Hello my name is John Doe, and I'm a fan of the")
def test_int8_from_pretrained(self):
r"""
@@ -869,4 +866,4 @@ class MixedInt8GPT2Test(MixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
- self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt
index 6087758f5a..04a400d8a9 100644
--- a/utils/not_doctested.txt
+++ b/utils/not_doctested.txt
@@ -943,6 +943,13 @@ src/transformers/pipelines/zero_shot_image_classification.py
src/transformers/pipelines/zero_shot_object_detection.py
src/transformers/processing_utils.py
src/transformers/pytorch_utils.py
+src/transformers/quantizers/auto.py
+src/transformers/quantizers/base.py
+src/transformers/quantizers/quantizer_awq.py
+src/transformers/quantizers/quantizer_bnb_4bit.py
+src/transformers/quantizers/quantizer_bnb_8bit.py
+src/transformers/quantizers/quantizer_gptq.py
+src/transformers/quantizers/quantizers_utils.py
src/transformers/sagemaker/trainer_sm.py
src/transformers/sagemaker/training_args_sm.py
src/transformers/testing_utils.py