FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models (#30806)
* add method * change method name * more comments * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixup * add docstrings and fix comment * warn users on the de-quantized dtype * Update src/transformers/quantizers/base.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/integrations/bitsandbytes.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final suggestion - use private method --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
|
|||||||
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
|
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Dequantizing `bitsandbytes` models
|
||||||
|
|
||||||
|
Once quantized, you can dequantize the model back to its original precision. Note that this might result in a small loss in model quality. Also make sure you have enough GPU RAM to fit the dequantized model.
|
||||||
|
Below is how to perform dequantization on a 4-bit model using `bitsandbytes`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
|
||||||
|
|
||||||
|
model_id = "facebook/opt-125m"
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model.dequantize()
|
||||||
|
|
||||||
|
text = tokenizer("Hello my name is", return_tensors="pt").to(0)
|
||||||
|
|
||||||
|
out = model.generate(**text)
|
||||||
|
print(tokenizer.decode(out[0]))
|
||||||
|
```
|
||||||
|
|
||||||
## EETQ
|
## EETQ
|
||||||
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUS. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.
|
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUS. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ _import_structure = {
|
|||||||
"replace_with_awq_linear",
|
"replace_with_awq_linear",
|
||||||
],
|
],
|
||||||
"bitsandbytes": [
|
"bitsandbytes": [
|
||||||
|
"dequantize_and_replace",
|
||||||
"get_keys_to_not_convert",
|
"get_keys_to_not_convert",
|
||||||
"replace_8bit_linear",
|
"replace_8bit_linear",
|
||||||
"replace_with_bnb_linear",
|
"replace_with_bnb_linear",
|
||||||
@@ -105,6 +106,7 @@ if TYPE_CHECKING:
|
|||||||
replace_with_awq_linear,
|
replace_with_awq_linear,
|
||||||
)
|
)
|
||||||
from .bitsandbytes import (
|
from .bitsandbytes import (
|
||||||
|
dequantize_and_replace,
|
||||||
get_keys_to_not_convert,
|
get_keys_to_not_convert,
|
||||||
replace_8bit_linear,
|
replace_8bit_linear,
|
||||||
replace_with_bnb_linear,
|
replace_with_bnb_linear,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
@@ -16,7 +17,9 @@ if is_bitsandbytes_available():
|
|||||||
from ..pytorch_utils import Conv1D
|
from ..pytorch_utils import Conv1D
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
|
import accelerate
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||||
from accelerate.utils import find_tied_parameters
|
from accelerate.utils import find_tied_parameters
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -322,3 +325,141 @@ def get_keys_to_not_convert(model):
|
|||||||
filtered_module_names.append(name)
|
filtered_module_names.append(name)
|
||||||
|
|
||||||
return filtered_module_names
|
return filtered_module_names
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||||
|
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """
    Helper function to dequantize 4bit or 8bit bnb weights.

    If the weight is not a bnb quantized weight, it will be returned as is.

    Args:
        weight (`torch.nn.Parameter`):
            The parameter to dequantize. It is treated as quantized only when its class is named
            `Params4bit` or `Int8Params`.
        state (*optional*):
            State object used on the 8-bit path; its `SCB`, `CxB`, `SB` and `formatB` attributes are read
            (and `SCB` / `CxB` / `SB` are filled in when missing). Required for `Int8Params` weights,
            ignored otherwise.

    Returns:
        The input `weight` itself when it is not a bnb weight, otherwise the dequantized tensor.

    Raises:
        TypeError: If `weight` is not a `torch.nn.Parameter`.
        ValueError: If `weight` is an `Int8Params` and no `state` was given.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    # The check is done on the class name so that this helper does not need bnb for regular weights.
    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    if cls_name == "Params4bit":
        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        # NOTE(review): `bnb_4bit_quant_type` selects the 4-bit format rather than a dtype; the option
        # named in this message may be inaccurate - confirm before rewording the user-facing text.
        logger.warning_once(
            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        )
        return output_tensor

    # 8-bit path: everything below reads attributes of `state`, so fail early with a clear message
    # instead of an AttributeError on `None`.
    if state is None:
        raise ValueError("`state` is required to dequantize `Int8Params` weights, but got `None`.")

    if state.SCB is None:
        state.SCB = weight.SCB

    # Multiply an identity matrix (sized to the last weight dimension, in half precision) with the
    # quantized weight through the int8 matmul kernels, then dequantize the int32 result.
    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_accelerate_new_hook(old_hook):
    r"""
    Builds a fresh hook of the same class as `old_hook`. Use it only if you know what you are doing !

    The hook class is looked up by name in `accelerate.hooks`, and the new instance is created from the
    attributes of `old_hook` whose names match parameters of that class' `__init__`.
    This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
    with some changes
    """
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    init_params = inspect.signature(hook_cls.__init__).parameters
    # Keep only the instance attributes that the constructor accepts.
    init_kwargs = {attr: value for attr, value in old_hook.__dict__.items() if attr in init_params}
    return hook_cls(**init_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Converts a quantized model into its dequantized original version. The newly converted model will have
    some performance drop compared to the original model before quantization - use it only for specific usecases
    such as QLoRA adapters merging.

    The function walks `model.named_children()` recursively and swaps every bnb linear layer of the class
    matching the quantization method for a `torch.nn.Linear` holding the dequantized weight.

    Args:
        model:
            The module whose children are inspected and replaced in place.
        modules_to_not_convert (*optional*):
            Names (or dotted key prefixes) of modules that must be left untouched. `None` is treated as an
            empty list.
        current_key_name (*optional*):
            List of the child names leading to `model`; used internally by the recursion.
        quantization_config:
            Object whose `quantization_method()` result selects the 8-bit (`"llm_int8"`) or 4-bit layer class.
        has_been_replaced (`bool`, defaults to `False`):
            Running flag carried through the recursion.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    quant_method = quantization_config.quantization_method()

    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

    # The default is `None`; normalise it so the membership tests below do not raise a TypeError.
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, target_cls) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)

            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                bias = getattr(module, "bias", None)

                device = module.weight.device
                with init_empty_weights():
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

                # Only the 8-bit path passes the module state to the dequantization helper.
                if quant_method == "llm_int8":
                    state = module.state
                else:
                    state = None

                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))

                if bias is not None:
                    new_module.bias = bias

                # Create a new hook and attach it in case we use accelerate
                if hasattr(module, "_hf_hook"):
                    old_hook = module._hf_hook
                    new_hook = _create_accelerate_new_hook(old_hook)

                    remove_hook_from_module(module)
                    add_hook_to_module(new_module, new_hook)

                new_module.to(device)
                model._modules[name] = new_module
                # Record the replacement; without this the flag never becomes True and the caller
                # always reports the dequantization as failed.
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = _dequantize_and_replace(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
|
||||||
|
|
||||||
|
|
||||||
|
def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    """
    Public entry point that dequantizes `model` through `_dequantize_and_replace`.

    `modules_to_not_convert` and `quantization_config` are forwarded unchanged. A warning is logged when
    the helper reports that nothing was replaced. The model returned by the helper is returned.
    """
    converted_model, replaced = _dequantize_and_replace(
        model,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )

    if replaced:
        return converted_model

    logger.warning(
        "For some reason the model has not been properly dequantized. You might see unexpected behavior."
    )
    return converted_model
|
||||||
|
|||||||
@@ -1327,6 +1327,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
self._backward_compatibility_gradient_checkpointing()
|
self._backward_compatibility_gradient_checkpointing()
|
||||||
|
|
||||||
|
def dequantize(self):
    """
    Dequantize the model through its attached quantizer, when the quantization method supports it.

    Returns whatever the `hf_quantizer` attribute's `dequantize` method returns for this model.

    Raises:
        ValueError: If the model has no `hf_quantizer` attribute (or it is `None`).
    """
    quantizer = getattr(self, "hf_quantizer", None)

    if quantizer is not None:
        return quantizer.dequantize(self)

    raise ValueError("You need to first quantize your model in order to dequantize it")
|
||||||
|
|
||||||
def _backward_compatibility_gradient_checkpointing(self):
|
def _backward_compatibility_gradient_checkpointing(self):
|
||||||
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
|
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
|
||||||
self.gradient_checkpointing_enable()
|
self.gradient_checkpointing_enable()
|
||||||
|
|||||||
@@ -194,6 +194,23 @@ class HfQuantizer(ABC):
|
|||||||
"""
|
"""
|
||||||
return self._process_model_after_weight_loading(model, **kwargs)
|
return self._process_model_after_weight_loading(model, **kwargs)
|
||||||
|
|
||||||
|
def dequantize(self, model):
    """
    Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
    Note not all quantization schemes support this.

    The actual work is delegated to `self._dequantize`; afterwards the `hf_quantizer` attribute is removed
    from the returned model.
    """
    dequantized_model = self._dequantize(model)

    # Delete the quantizer attached to the model
    delattr(dequantized_model, "hf_quantizer")

    return dequantized_model
|
||||||
|
|
||||||
|
def _dequantize(self, model):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _process_model_before_weight_loading(self, model, **kwargs):
|
def _process_model_before_weight_loading(self, model, **kwargs):
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -312,3 +312,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
|||||||
@property
|
@property
|
||||||
def is_trainable(self) -> bool:
|
def is_trainable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _dequantize(self, model):
    """
    Dequantize `model` by calling `dequantize_and_replace` with this quantizer's
    `modules_to_not_convert` and `quantization_config`, and return the result.
    """
    # Imported locally, as in the rest of this method's original form.
    from ..integrations import dequantize_and_replace

    return dequantize_and_replace(
        model,
        self.modules_to_not_convert,
        quantization_config=self.quantization_config,
    )
|
||||||
|
|||||||
@@ -281,3 +281,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
|||||||
@property
|
@property
|
||||||
def is_trainable(self) -> bool:
|
def is_trainable(self) -> bool:
|
||||||
return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
|
return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
|
||||||
|
|
||||||
|
def _dequantize(self, model):
    """
    Dequantize `model` by calling `dequantize_and_replace` with this quantizer's
    `modules_to_not_convert` and `quantization_config`, and return the result.
    """
    # Imported locally, as in the rest of this method's original form.
    from ..integrations import dequantize_and_replace

    return dequantize_and_replace(
        model,
        self.modules_to_not_convert,
        quantization_config=self.quantization_config,
    )
|
||||||
|
|||||||
@@ -239,6 +239,23 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
|
|
||||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
def test_generate_quality_dequantize(self):
    r"""
    Test that loading the model in 4-bit and then dequantizing it produces correct results
    """
    bnb_config = BitsAndBytesConfig(load_in_4bit=True)
    model_4bit = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )

    model_4bit.dequantize()

    input_ids = self.tokenizer(self.input_text, return_tensors="pt")["input_ids"].to(0)
    generated = model_4bit.generate(input_ids=input_ids, max_new_tokens=10)
    decoded_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)

    self.assertIn(decoded_text, self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
def test_device_and_dtype_assignment(self):
|
def test_device_and_dtype_assignment(self):
|
||||||
r"""
|
r"""
|
||||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||||
|
|||||||
@@ -285,6 +285,23 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
def test_generate_quality_dequantize(self):
    r"""
    Test that loading the model in 8-bit and then dequantizing it produces correct results
    """
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    model_8bit = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )

    model_8bit.dequantize()

    input_ids = self.tokenizer(self.input_text, return_tensors="pt")["input_ids"].to(0)
    generated = model_8bit.generate(input_ids=input_ids, max_new_tokens=10)
    decoded_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)

    self.assertIn(decoded_text, self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
def test_raise_if_config_and_load_in_8bit(self):
|
def test_raise_if_config_and_load_in_8bit(self):
|
||||||
r"""
|
r"""
|
||||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||||
|
|||||||
Reference in New Issue
Block a user