FEAT : Adding BitNet quantization method to HFQuantizer (#33410)
* rebasing changes * fixing style * adding some doc to functions * remove bitblas * change dtype * fixing check_code_quality * fixing import order * adding doc to tree * Small update on BitLinear * adding some tests * sorting imports * small update * reformatting * reformatting * reformatting with ruff * adding assert * changes after review * update disk offloading * adapting after review * Update after review * add is_serializable back * fixing style * adding serialization test * make style * small updates after review
This commit is contained in:
@@ -179,6 +179,8 @@
|
||||
title: Optimum
|
||||
- local: quantization/torchao
|
||||
title: TorchAO
|
||||
- local: quantization/bitnet
|
||||
title: BitNet
|
||||
- local: quantization/compressed_tensors
|
||||
title: compressed-tensors
|
||||
- local: quantization/contribute
|
||||
|
||||
@@ -68,3 +68,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
## BitNetConfig
|
||||
|
||||
[[autodoc]] BitNetConfig
|
||||
|
||||
75
docs/source/en/quantization/bitnet.md
Normal file
75
docs/source/en/quantization/bitnet.md
Normal file
@@ -0,0 +1,75 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# BitNet
|
||||
|
||||
[BitNet](https://arxiv.org/abs/2402.17764) replaces traditional Linear layers in Multi-Head Attention and Feed-Forward Networks with specialized layers called BitLinear with ternary (or binary in the older version) precision. The BitLinear layers introduced here quantize the weights using ternary precision (with values of -1, 0, and 1) and quantize the activations to 8-bit precision.
|
||||
|
||||
|
||||
<figure style="text-align: center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/1.58llm_extreme_quantization/bitlinear.png" alt="Alt Text" />
|
||||
<figcaption>The architecture of BitNet with BitLinear layers</figcaption>
|
||||
</figure>
|
||||
|
||||
During training, we start by quantizing the weights into ternary values, using symmetric per tensor quantization. First, we compute the average of the absolute values of the weight matrix and use this as a scale. We then divide the weights by the scale, round the values, constrain them between -1 and 1, and finally rescale them to continue in full precision.
|
||||
|
||||
$$
|
||||
scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}
|
||||
$$
|
||||
|
||||
$$
|
||||
W_q = \text{clamp}_{[-1,1]}(\text{round}(W*scale))
|
||||
$$
|
||||
|
||||
$$
|
||||
W_{dequantized} = W_q*scale_w
|
||||
$$
|
||||
|
||||
Activations are then quantized to a specified bit-width (e.g., 8-bit) using [absmax](https://arxiv.org/pdf/2208.07339) quantization (symmetric per channel quantization). This involves scaling the activations into a range [−128,127[. The quantization formula is:
|
||||
|
||||
$$
|
||||
scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}}
|
||||
$$
|
||||
|
||||
$$
|
||||
X_q = \text{clamp}_{[-128,127]}(\text{round}(X*scale))
|
||||
$$
|
||||
|
||||
$$
|
||||
X_{dequantized} = X_q * scale_x
|
||||
$$
|
||||
|
||||
To learn more about how we trained, and fine-tuned bitnet models checkout the blogpost [here](https://huggingface.co/blog/1_58_llm_extreme_quantization)
|
||||
|
||||
## Load a BitNet Model from the Hub
|
||||
BitNet models can't be quantized on the fly—they need to be pre-trained or fine-tuned with the quantization applied (it's a Quantization aware training technique). Once trained, these models are already quantized and available as packed versions on the hub.
|
||||
|
||||
A quantized model can be load :
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM
|
||||
path = "/path/to/model"
|
||||
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
|
||||
```
|
||||
## Pre-training / Fine-tuning a BitNet Model
|
||||
|
||||
If you're looking to pre-train or fine-tune your own 1.58-bit model using Nanotron, check out this [PR](https://github.com/huggingface/nanotron/pull/180), all you need to get started is there !
|
||||
|
||||
For fine-tuning, you'll need to convert the model from Hugging Face format to Nanotron format (which has some differences). You can find the conversion steps in this [PR](https://github.com/huggingface/nanotron/pull/174).
|
||||
|
||||
## Kernels
|
||||
|
||||
In our initial version, we chose to use `@torch.compile` to unpack the weights and perform the forward pass. It’s very straightforward to implement and delivers significant speed improvements. We plan to integrate additional optimized kernels in future versions.
|
||||
@@ -968,6 +968,7 @@ _import_structure = {
|
||||
"utils.quantization_config": [
|
||||
"AqlmConfig",
|
||||
"AwqConfig",
|
||||
"BitNetConfig",
|
||||
"BitsAndBytesConfig",
|
||||
"CompressedTensorsConfig",
|
||||
"EetqConfig",
|
||||
@@ -5869,6 +5870,7 @@ if TYPE_CHECKING:
|
||||
from .utils.quantization_config import (
|
||||
AqlmConfig,
|
||||
AwqConfig,
|
||||
BitNetConfig,
|
||||
BitsAndBytesConfig,
|
||||
CompressedTensorsConfig,
|
||||
EetqConfig,
|
||||
|
||||
@@ -25,6 +25,12 @@ _import_structure = {
|
||||
"replace_quantization_scales",
|
||||
"replace_with_awq_linear",
|
||||
],
|
||||
"bitnet": [
|
||||
"BitLinear",
|
||||
"pack_weights",
|
||||
"replace_with_bitnet_linear",
|
||||
"unpack_weights",
|
||||
],
|
||||
"bitsandbytes": [
|
||||
"dequantize_and_replace",
|
||||
"get_keys_to_not_convert",
|
||||
@@ -120,6 +126,12 @@ if TYPE_CHECKING:
|
||||
replace_quantization_scales,
|
||||
replace_with_awq_linear,
|
||||
)
|
||||
from .bitnet import (
|
||||
BitLinear,
|
||||
pack_weights,
|
||||
replace_with_bitnet_linear,
|
||||
unpack_weights,
|
||||
)
|
||||
from .bitsandbytes import (
|
||||
dequantize_and_replace,
|
||||
get_keys_to_not_convert,
|
||||
|
||||
286
src/transformers/integrations/bitnet.py
Normal file
286
src/transformers/integrations/bitnet.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# the weights are ternary so can be represented with 2 bits, and they are packed in uint8 tensors, hence the number of values per item is 4
|
||||
VALUES_PER_ITEM = 4
|
||||
|
||||
|
||||
def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Packs a tensor of quantized weights into a compact format using 2 bits per value.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
quantized_weights : torch.Tensor
|
||||
A tensor containing ternary quantized weights with values in {-1, 0, 1}. These values are adjusted to
|
||||
{0, 1, 2} before being packed.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
A packed tensor where each element stores 4 quantized values (each using 2 bits) in an 8-bit format.
|
||||
"""
|
||||
|
||||
original_shape = quantized_weights.shape
|
||||
|
||||
row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
|
||||
|
||||
if len(original_shape) == 1:
|
||||
packed_tensor_shape = (row_dim,)
|
||||
else:
|
||||
packed_tensor_shape = (row_dim, *original_shape[1:])
|
||||
|
||||
quantized_weights += 1
|
||||
packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
|
||||
unpacked = quantized_weights.to(torch.uint8)
|
||||
|
||||
it = min(VALUES_PER_ITEM, (original_shape[0] // row_dim) + 1)
|
||||
for i in range(it):
|
||||
start = i * row_dim
|
||||
end = min(start + row_dim, original_shape[0])
|
||||
packed[: (end - start)] |= unpacked[start:end] << 2 * i
|
||||
|
||||
return packed
|
||||
|
||||
|
||||
@torch.compile
|
||||
def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
packed : torch.Tensor
|
||||
A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
|
||||
dtype : torch.dtype
|
||||
The dtype of the returned Tensor
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.
|
||||
|
||||
Example:
|
||||
--------
|
||||
packed = torch.tensor([[0b10100001, 0b00011000],
|
||||
[0b10010000, 0b00001010]], dtype=torch.uint8)
|
||||
|
||||
# Unpack the values
|
||||
unpacked = unpack_weights(packed)
|
||||
|
||||
# Resulting unpacked tensor
|
||||
print(unpacked)
|
||||
# Output: tensor([[ 0, -1],
|
||||
[-1, 1],
|
||||
[-1, 1],
|
||||
[-1, 1],
|
||||
[ 1, 0],
|
||||
[ 0, -1],
|
||||
[ 1, -1],
|
||||
[ 1, -1]])
|
||||
|
||||
Explanation of the example:
|
||||
---------------------------
|
||||
Let's take the first value for example 0b10100001, we we will only focus on the first column,
|
||||
because every element is unpacked across the first dimension
|
||||
- First 2 bits: `01` → 0 at [0][0]
|
||||
- Second 2 bits: `00` → -1 at [0][2]
|
||||
- Third 2 bits: `10` → 1 at [0][4]
|
||||
- Fourth 2 bits: `10` → 1 at [0][6]
|
||||
the second value of the same row (0b10010000) will give the values for [0][1], [0][3], [0][5], [0][7]
|
||||
|
||||
We subtract 1 because during the packing process, it's easier to work with values like 0, 1, and 2. To make this possible,
|
||||
we add 1 to the original ternary weights (which are typically -1, 0, and 1) when packing them. When unpacking, we reverse
|
||||
this by subtracting 1 to restore the original ternary values.
|
||||
"""
|
||||
packed_shape = packed.shape
|
||||
|
||||
if len(packed_shape) == 1:
|
||||
original_row_dim = packed_shape[0] * VALUES_PER_ITEM
|
||||
unpacked_shape = (original_row_dim,)
|
||||
else:
|
||||
original_row_dim = packed_shape[0] * VALUES_PER_ITEM
|
||||
unpacked_shape = (original_row_dim, *packed_shape[1:])
|
||||
|
||||
unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
|
||||
|
||||
for i in range(VALUES_PER_ITEM):
|
||||
start = i * packed_shape[0]
|
||||
end = start + packed_shape[0]
|
||||
mask = 3 << (2 * i)
|
||||
unpacked[start:end] = (packed & mask) >> (2 * i)
|
||||
|
||||
return unpacked.to(dtype) - 1
|
||||
|
||||
|
||||
class BitLinear(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.zeros(
|
||||
(out_features // VALUES_PER_ITEM, in_features),
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"weight_scale",
|
||||
torch.ones(
|
||||
(1),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
if bias:
|
||||
self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@torch.compile
|
||||
def activation_quant(self, input, num_bits=8):
|
||||
"""
|
||||
Activation function : Performs symmetric, per-token quantization on the input activations.
|
||||
Parameters:
|
||||
-----------
|
||||
x : torch.Tensor
|
||||
Input activations to be quantized.
|
||||
num_bits : int, optional (default=8)
|
||||
Number of bits to use for quantization, determining the quantization range.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
result : torch.Tensor
|
||||
Quantized activation tensor, with values mapped to an `int8` range.
|
||||
scale : torch.Tensor
|
||||
The per-channel scaling factors used to quantize the tensor.
|
||||
"""
|
||||
Qn = -(2 ** (num_bits - 1))
|
||||
Qp = 2 ** (num_bits - 1) - 1
|
||||
scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
|
||||
result = (input * scale).round().clamp(Qn, Qp)
|
||||
return result.to(torch.int8), scale
|
||||
|
||||
@torch.compile
|
||||
def post_quant_process(self, input, input_scale, weight_scale):
|
||||
out = input / (input_scale * weight_scale)
|
||||
return out
|
||||
|
||||
def forward(self, input):
|
||||
w = self.weight
|
||||
w_quant = unpack_weights(w, dtype=self.dtype)
|
||||
input_quant, input_scale = self.activation_quant(input)
|
||||
y = F.linear(input_quant.to(self.dtype), w_quant)
|
||||
y = self.post_quant_process(y, self.weight_scale, input_scale)
|
||||
if self.bias is not None:
|
||||
y += self.bias.view(1, -1).expand_as(y)
|
||||
return y
|
||||
|
||||
|
||||
def _replace_with_bitnet_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
pre_quantized=False,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
||||
with init_empty_weights():
|
||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
model._modules[name] = BitLinear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
)
|
||||
has_been_replaced = True
|
||||
model._modules[name].requires_grad_(False)
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_bitnet_linear(
|
||||
module,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
current_key_name=current_key_name,
|
||||
quantization_config=quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_bitnet_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
pre_quantized=False,
|
||||
):
|
||||
"""
|
||||
A helper function to replace all `torch.nn.Linear` modules by `BitLinear158` modules`.
|
||||
|
||||
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
|
||||
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
|
||||
CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
|
||||
Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
|
||||
for numerical stability reasons.
|
||||
current_key_name (`List[`str`]`, *optional*):
|
||||
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
||||
it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
|
||||
`disk`).
|
||||
"""
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
if quantization_config and quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
modules_to_not_convert = list(set(modules_to_not_convert))
|
||||
model, has_been_replaced = _replace_with_bitnet_linear(
|
||||
model,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
pre_quantized=pre_quantized,
|
||||
)
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model using bitnet but no linear modules were found in your model."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
|
||||
return model
|
||||
@@ -18,6 +18,7 @@ from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..utils.quantization_config import (
|
||||
AqlmConfig,
|
||||
AwqConfig,
|
||||
BitNetConfig,
|
||||
BitsAndBytesConfig,
|
||||
CompressedTensorsConfig,
|
||||
EetqConfig,
|
||||
@@ -31,6 +32,7 @@ from ..utils.quantization_config import (
|
||||
)
|
||||
from .quantizer_aqlm import AqlmHfQuantizer
|
||||
from .quantizer_awq import AwqQuantizer
|
||||
from .quantizer_bitnet import BitNetHfQuantizer
|
||||
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
|
||||
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
|
||||
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
|
||||
@@ -54,6 +56,7 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"compressed-tensors": CompressedTensorsHfQuantizer,
|
||||
"fbgemm_fp8": FbgemmFp8HfQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
"bitnet": BitNetHfQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@@ -68,6 +71,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"fbgemm_fp8": FbgemmFp8Config,
|
||||
"torchao": TorchAoConfig,
|
||||
"bitnet": BitNetConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
115
src/transformers/quantizers/quantizer_bitnet.py
Normal file
115
src/transformers/quantizers/quantizer_bitnet.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
|
||||
from .base import HfQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BitNetHfQuantizer(HfQuantizer):
|
||||
"""
|
||||
1.58-bit quantization from BitNet quantization method:
|
||||
Before loading: it converts the linear layers into BitLinear layers during loading.
|
||||
|
||||
Checkout the paper introducing this method : https://arxiv.org/pdf/2402.17764
|
||||
"""
|
||||
|
||||
requires_parameters_quantization = False
|
||||
requires_calibration = True
|
||||
|
||||
required_packages = ["accelerate"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Loading a BitNet quantized model requires accelerate (`pip install accelerate`)")
|
||||
|
||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||
raise ValueError(
|
||||
"Loading ternary weights from tf/flax is currently not supported, please make"
|
||||
" sure the weights are in PyTorch format."
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You don't have a GPU available to load the model, the inference will be slow because of weight unpacking"
|
||||
)
|
||||
return
|
||||
|
||||
device_map = kwargs.get("device_map", None)
|
||||
if device_map is None:
|
||||
logger.warning_once(
|
||||
"You have loaded a BitNet model on CPU and have a CUDA device available, make sure to set "
|
||||
"your model on a GPU device in order to run your model."
|
||||
)
|
||||
elif device_map is not None:
|
||||
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
|
||||
raise ValueError(
|
||||
"You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
|
||||
"This is not supported. Please remove the CPU or disk device from the device_map."
|
||||
)
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
return model
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
device_map,
|
||||
keep_in_fp32_modules: List[str] = [],
|
||||
**kwargs,
|
||||
):
|
||||
from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
|
||||
|
||||
self.modules_to_not_convert = get_keys_to_not_convert(model)
|
||||
|
||||
if self.quantization_config.modules_to_not_convert is not None:
|
||||
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
|
||||
|
||||
model = replace_with_bitnet_linear(
|
||||
model,
|
||||
modules_to_not_convert=self.modules_to_not_convert,
|
||||
quantization_config=self.quantization_config,
|
||||
pre_quantized=self.pre_quantized,
|
||||
)
|
||||
|
||||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
||||
return max_memory
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
target_dtype = torch.int8
|
||||
return target_dtype
|
||||
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return False
|
||||
@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
|
||||
COMPRESSED_TENSORS = "compressed-tensors"
|
||||
FBGEMM_FP8 = "fbgemm_fp8"
|
||||
TORCHAO = "torchao"
|
||||
BITNET = "bitnet"
|
||||
|
||||
|
||||
class AWQLinearVersion(str, Enum):
|
||||
@@ -1308,4 +1309,22 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.quant_type_kwargs.items())})"
|
||||
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BitNetConfig(QuantizationConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
modules_to_not_convert: Optional[List] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.BITNET
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct
|
||||
"""
|
||||
pass
|
||||
|
||||
0
tests/quantization/bitnet_integration/__init__.py
Normal file
0
tests/quantization/bitnet_integration/__init__.py
Normal file
225
tests/quantization/bitnet_integration/test_bitnet.py
Normal file
225
tests/quantization/bitnet_integration/test_bitnet.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitNetConfig,
|
||||
OPTForCausalLM,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class BitNetConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
|
||||
"""
|
||||
quantization_config = BitNetConfig()
|
||||
config_to_dict = quantization_config.to_dict()
|
||||
|
||||
for key in config_to_dict:
|
||||
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_accelerate
|
||||
class BitNetTest(unittest.TestCase):
|
||||
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
|
||||
device = "cuda"
|
||||
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Load the model
|
||||
"""
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def test_replace_with_bitlinear(self):
|
||||
from transformers.integrations import BitLinear, replace_with_bitnet_linear
|
||||
|
||||
model_id = "facebook/opt-350m"
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
|
||||
with init_empty_weights():
|
||||
model = OPTForCausalLM(config)
|
||||
|
||||
nb_linears = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
nb_linears += 1
|
||||
|
||||
model = replace_with_bitnet_linear(model)
|
||||
nb_bitnet_linear = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, BitLinear):
|
||||
nb_bitnet_linear += 1
|
||||
|
||||
self.assertEqual(nb_linears - 1, nb_bitnet_linear)
|
||||
|
||||
def test_quantized_model(self, quantized_model, tokenizer):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly
|
||||
"""
|
||||
input_text = "What are we having for dinner?"
|
||||
expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||
|
||||
def test_packing_unpacking(self):
|
||||
"""
|
||||
Simple test the packing and unpacking logic
|
||||
"""
|
||||
|
||||
from transformers.integrations import pack_weights, unpack_weights
|
||||
|
||||
u = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8)
|
||||
unpacked_u = unpack_weights(u, dtype=torch.bfloat16)
|
||||
self.assertEqual(pack_weights(unpacked_u), u)
|
||||
|
||||
def test_activation_quant(self):
|
||||
"""
|
||||
test the activation function behaviour
|
||||
"""
|
||||
|
||||
from transformers.integrations import BitLinear
|
||||
|
||||
layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
|
||||
layer.to(self.device)
|
||||
|
||||
input_tensor = torch.tensor([[1.0, -1.0, -1.0, 1.0], [1.0, -1.0, 1.0, 1.0]], dtype=torch.float32).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# Quantize the input tensor
|
||||
quantized_tensor, scale = layer.activation_quant(input_tensor)
|
||||
|
||||
# Verify the output quantized tensor
|
||||
self.assertEqual(quantized_tensor, input_tensor)
|
||||
|
||||
# Verify the scale tensor
|
||||
self.assertEqual(scale, 127)
|
||||
|
||||
def test_weights_dtype(self):
|
||||
"""
|
||||
test the weights dtype after loading
|
||||
"""
|
||||
|
||||
self_attn_q = self.quantized_model.model.layers[0].self_attn.q_proj.weight
|
||||
self_attn_k = self.quantized_model.model.layers[0].self_attn.k_proj.weight
|
||||
self_attn_v = self.quantized_model.model.layers[0].self_attn.v_proj.weight
|
||||
self_attn_o = self.quantized_model.model.layers[0].self_attn.o_proj.weight
|
||||
mlp_gate = self.quantized_model.model.layers[0].mlp.gate_proj.weight
|
||||
mlp_up = self.quantized_model.model.layers[0].mlp.up_proj.weight
|
||||
mlp_down = self.quantized_model.model.layers[0].mlp.down_proj.weight
|
||||
|
||||
self.assertEqual(self_attn_q.dtype, torch.uint8)
|
||||
self.assertEqual(self_attn_k.dtype, torch.uint8)
|
||||
self.assertEqual(self_attn_v.dtype, torch.uint8)
|
||||
self.assertEqual(self_attn_o.dtype, torch.uint8)
|
||||
self.assertEqual(mlp_up.dtype, torch.uint8)
|
||||
self.assertEqual(mlp_gate.dtype, torch.uint8)
|
||||
self.assertEqual(mlp_down.dtype, torch.uint8)
|
||||
|
||||
def test_replace_with_bitlinear_shape(self):
|
||||
"""
|
||||
test that the BitNet layer weight shapes are correct, and the weight_scale is correctly initialized to 1
|
||||
"""
|
||||
|
||||
from transformers.integrations import replace_with_bitnet_linear
|
||||
|
||||
out_features = 1024
|
||||
in_features = 512
|
||||
|
||||
class SimpleLinearModule(torch.nn.Module):
|
||||
"""
|
||||
Simple class to test BitLinear
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int = in_features,
|
||||
out_features: int = out_features,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
model = SimpleLinearModule()
|
||||
replace_with_bitnet_linear(model)
|
||||
|
||||
self.assertEqual(list(model.linear.weight.shape), [out_features // 4, in_features])
|
||||
self.assertEqual(model.linear.weight_scale, 1)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_accelerate
|
||||
class BitNetSerializationTest(unittest.TestCase):
|
||||
def test_model_serialization(self):
|
||||
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
|
||||
device = "cuda"
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_ref = quantized_model.forward(input_tensor).logits
|
||||
|
||||
# Save
|
||||
saved_model_id = "quant_model"
|
||||
quantized_model.save_pretrained(saved_model_id)
|
||||
|
||||
# Remove old model
|
||||
del quantized_model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load and check if the logits match
|
||||
model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||
Reference in New Issue
Block a user