FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models (#30806)
* add method * change method name * more comments * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixup * add docstrings and fix comment * warn users on the de-quantized dtype * Update src/transformers/quantizers/base.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/integrations/bitsandbytes.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final suggestion - use private method --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
|
||||
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
|
||||
```
|
||||
|
||||
### Dequantizing `bitsandbytes` models
|
||||
|
||||
Once quantized, you can dequantize the model to the original precision. Note this might result in a small quality loss of the model. Make also sure to have enough GPU RAM to fit the dequantized model.
|
||||
Below is how to perform dequantization on a 4-bit model using `bitsandbytes`.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
|
||||
|
||||
model_id = "facebook/opt-125m"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, BitsAndBytesConfig(load_in_4bit=True))
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model.dequantize()
|
||||
|
||||
text = tokenizer("Hello my name is", return_tensors="pt").to(0)
|
||||
|
||||
out = model.generate(**text)
|
||||
print(tokenizer.decode(out[0]))
|
||||
```
|
||||
|
||||
## EETQ
|
||||
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUS. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ _import_structure = {
|
||||
"replace_with_awq_linear",
|
||||
],
|
||||
"bitsandbytes": [
|
||||
"dequantize_and_replace",
|
||||
"get_keys_to_not_convert",
|
||||
"replace_8bit_linear",
|
||||
"replace_with_bnb_linear",
|
||||
@@ -105,6 +106,7 @@ if TYPE_CHECKING:
|
||||
replace_with_awq_linear,
|
||||
)
|
||||
from .bitsandbytes import (
|
||||
dequantize_and_replace,
|
||||
get_keys_to_not_convert,
|
||||
replace_8bit_linear,
|
||||
replace_with_bnb_linear,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
@@ -16,7 +17,9 @@ if is_bitsandbytes_available():
|
||||
from ..pytorch_utils import Conv1D
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -322,3 +325,141 @@ def get_keys_to_not_convert(model):
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
|
||||
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
|
||||
"""
|
||||
Helper function to dequantize 4bit or 8bit bnb weights.
|
||||
|
||||
If the weight is not a bnb quantized weight, it will be returned as is.
|
||||
"""
|
||||
if not isinstance(weight, torch.nn.Parameter):
|
||||
raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
|
||||
|
||||
cls_name = weight.__class__.__name__
|
||||
if cls_name not in ("Params4bit", "Int8Params"):
|
||||
return weight
|
||||
|
||||
if cls_name == "Params4bit":
|
||||
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
|
||||
logger.warning_once(
|
||||
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
|
||||
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
|
||||
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
|
||||
im, Sim = bnb.functional.transform(im, "col32")
|
||||
if state.CxB is None:
|
||||
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
|
||||
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
|
||||
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
|
||||
|
||||
|
||||
def _create_accelerate_new_hook(old_hook):
|
||||
r"""
|
||||
Creates a new hook based on the old hook. Use it only if you know what you are doing !
|
||||
This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
|
||||
with some changes
|
||||
"""
|
||||
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
|
||||
old_hook_attr = old_hook.__dict__
|
||||
filtered_old_hook_attr = {}
|
||||
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
|
||||
for k in old_hook_attr.keys():
|
||||
if k in old_hook_init_signature.parameters:
|
||||
filtered_old_hook_attr[k] = old_hook_attr[k]
|
||||
new_hook = old_hook_cls(**filtered_old_hook_attr)
|
||||
return new_hook
|
||||
|
||||
|
||||
def _dequantize_and_replace(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""
|
||||
Converts a quantized model into its dequantized original version. The newly converted model will have
|
||||
some performance drop compared to the original model before quantization - use it only for specific usecases
|
||||
such as QLoRA adapters merging.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
quant_method = quantization_config.quantization_method()
|
||||
|
||||
target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
|
||||
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, target_cls) and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
|
||||
if not any(
|
||||
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
|
||||
):
|
||||
bias = getattr(module, "bias", None)
|
||||
|
||||
device = module.weight.device
|
||||
with init_empty_weights():
|
||||
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
|
||||
|
||||
if quant_method == "llm_int8":
|
||||
state = module.state
|
||||
else:
|
||||
state = None
|
||||
|
||||
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
|
||||
|
||||
if bias is not None:
|
||||
new_module.bias = bias
|
||||
|
||||
# Create a new hook and attach it in case we use accelerate
|
||||
if hasattr(module, "_hf_hook"):
|
||||
old_hook = module._hf_hook
|
||||
new_hook = _create_accelerate_new_hook(old_hook)
|
||||
|
||||
remove_hook_from_module(module)
|
||||
add_hook_to_module(new_module, new_hook)
|
||||
|
||||
new_module.to(device)
|
||||
model._modules[name] = new_module
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _dequantize_and_replace(
|
||||
module,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def dequantize_and_replace(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
quantization_config=None,
|
||||
):
|
||||
model, has_been_replaced = _dequantize_and_replace(
|
||||
model,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@@ -1327,6 +1327,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
self.init_weights()
|
||||
self._backward_compatibility_gradient_checkpointing()
|
||||
|
||||
def dequantize(self):
|
||||
"""
|
||||
Potentially dequantize the model in case it has been quantized by a quantization method that support
|
||||
dequantization.
|
||||
"""
|
||||
hf_quantizer = getattr(self, "hf_quantizer", None)
|
||||
|
||||
if hf_quantizer is None:
|
||||
raise ValueError("You need to first quantize your model in order to dequantize it")
|
||||
|
||||
return hf_quantizer.dequantize(self)
|
||||
|
||||
def _backward_compatibility_gradient_checkpointing(self):
|
||||
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
|
||||
self.gradient_checkpointing_enable()
|
||||
|
||||
@@ -194,6 +194,23 @@ class HfQuantizer(ABC):
|
||||
"""
|
||||
return self._process_model_after_weight_loading(model, **kwargs)
|
||||
|
||||
def dequantize(self, model):
|
||||
"""
|
||||
Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance.
|
||||
Note not all quantization schemes support this.
|
||||
"""
|
||||
model = self._dequantize(model)
|
||||
|
||||
# Delete quantizer and quantization config
|
||||
del model.hf_quantizer
|
||||
|
||||
return model
|
||||
|
||||
def _dequantize(self, model):
|
||||
raise NotImplementedError(
|
||||
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _process_model_before_weight_loading(self, model, **kwargs):
|
||||
...
|
||||
|
||||
@@ -312,3 +312,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return True
|
||||
|
||||
def _dequantize(self, model):
|
||||
from ..integrations import dequantize_and_replace
|
||||
|
||||
model = dequantize_and_replace(
|
||||
model, self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
)
|
||||
return model
|
||||
|
||||
@@ -281,3 +281,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
|
||||
|
||||
def _dequantize(self, model):
|
||||
from ..integrations import dequantize_and_replace
|
||||
|
||||
model = dequantize_and_replace(
|
||||
model, self.modules_to_not_convert, quantization_config=self.quantization_config
|
||||
)
|
||||
return model
|
||||
|
||||
@@ -239,6 +239,23 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_generate_quality_dequantize(self):
|
||||
r"""
|
||||
Test that loading the model and unquantize it produce correct results
|
||||
"""
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model_4bit = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, quantization_config=bnb_config, device_map="auto"
|
||||
)
|
||||
|
||||
model_4bit.dequantize()
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
|
||||
@@ -285,6 +285,23 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_generate_quality_dequantize(self):
|
||||
r"""
|
||||
Test that loading the model and dequantizing it produce correct results
|
||||
"""
|
||||
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
model_8bit = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, quantization_config=bnb_config, device_map="auto"
|
||||
)
|
||||
|
||||
model_8bit.dequantize()
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_raise_if_config_and_load_in_8bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||
|
||||
Reference in New Issue
Block a user