FEAT : Adding BitNet quantization method to HFQuantizer (#33410)

* rebasing changes

* fixing style

* adding some doc to functions

* remove bitblas

* change dtype

* fixing check_code_quality

* fixing import order

* adding doc to tree

* Small update on BitLinear

* adding some tests

* sorting imports

* small update

* reformatting

* reformatting

* reformatting with ruff

* adding assert

* changes after review

* update disk offloading

* adapting after review

* Update after review

* add is_serializable back

* fixing style

* adding serialization test

* make style

* small updates after review
Mohamed Mekkouri
2024-10-09 17:51:41 +02:00
committed by GitHub
parent 48461c0fe2
commit 36d410dab6
11 changed files with 745 additions and 1 deletions

View File

@@ -179,6 +179,8 @@
title: Optimum
- local: quantization/torchao
title: TorchAO
- local: quantization/bitnet
title: BitNet
- local: quantization/compressed_tensors
title: compressed-tensors
- local: quantization/contribute

View File

@@ -68,3 +68,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## TorchAoConfig
[[autodoc]] TorchAoConfig
## BitNetConfig
[[autodoc]] BitNetConfig

View File

@@ -0,0 +1,75 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# BitNet
[BitNet](https://arxiv.org/abs/2402.17764) replaces traditional Linear layers in Multi-Head Attention and Feed-Forward Networks with specialized layers called BitLinear with ternary (or binary in the older version) precision. The BitLinear layers introduced here quantize the weights using ternary precision (with values of -1, 0, and 1) and quantize the activations to 8-bit precision.
<figure style="text-align: center;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/1.58llm_extreme_quantization/bitlinear.png" alt="The architecture of BitNet with BitLinear layers" />
<figcaption>The architecture of BitNet with BitLinear layers</figcaption>
</figure>
During training, we start by quantizing the weights into ternary values, using symmetric per-tensor quantization. First, we compute the average of the absolute values of the weight matrix and use its reciprocal as a scale. We then multiply the weights by the scale, round the values, constrain them between -1 and 1, and finally rescale them so that training continues in full precision.
$$
scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}
$$
$$
W_q = \text{clamp}_{[-1,1]}(\text{round}(W \cdot scale_w))
$$
$$
W_{dequantized} = \frac{W_q}{scale_w}
$$
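The weight quantization steps can be sketched in plain Python on nested lists. This is only an illustration, not the library's implementation: the helper name `quantize_weights_ternary` and the `eps` guard against an all-zero matrix are assumptions made for the example. Because `scale` here is the reciprocal of the mean absolute value, the sketch maps the ternary values back to the original range by dividing by it.

```python
def quantize_weights_ternary(weights, eps=1e-5):
    """Absmean ternary quantization of a 2-D list of floats (illustrative only)."""
    flat = [abs(w) for row in weights for w in row]
    mean_abs = max(sum(flat) / len(flat), eps)  # eps guard is an assumption
    scale = 1.0 / mean_abs
    # Scale, round, then clamp every weight to {-1, 0, 1}.
    w_q = [[max(-1, min(1, round(w * scale))) for w in row] for row in weights]
    # Map the ternary values back to the original weight range.
    w_dq = [[q / scale for q in row] for row in w_q]
    return w_q, w_dq, scale


weights = [[0.9, -0.05, 0.4], [-1.2, 0.1, -0.35]]
w_q, w_dq, scale = quantize_weights_ternary(weights)
print(w_q)  # [[1, 0, 1], [-1, 0, -1]]
```

Every dequantized weight ends up as one of three values: minus the mean absolute weight, zero, or plus the mean absolute weight.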
Activations are then quantized to a specified bit-width (e.g., 8-bit) using [absmax](https://arxiv.org/pdf/2208.07339) quantization (symmetric quantization with one scale per token, computed along the last dimension). This scales the activations into the range [-128, 127]. The quantization formulas are:
$$
scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}}
$$
$$
X_q = \text{clamp}_{[-128,127]}(\text{round}(X \cdot scale_x))
$$
$$
X_{dequantized} = \frac{X_q}{scale_x}
$$
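As a rough illustration of absmax quantization, the sketch below quantizes each row (token) of a small activation matrix with its own scale. It uses plain Python lists rather than tensors; the function name `quantize_activations` and the `eps` floor are assumptions for the example. Python's `round` rounds halves to the nearest even integer, which is also how `torch.round` behaves.

```python
def quantize_activations(x, num_bits=8, eps=1e-5):
    """Per-row absmax quantization of a 2-D list of floats (illustrative only)."""
    q_min = -(2 ** (num_bits - 1))     # -128 for 8 bits
    q_max = 2 ** (num_bits - 1) - 1    # 127 for 8 bits
    x_q, scales = [], []
    for row in x:
        abs_max = max(max(abs(v) for v in row), eps)  # eps floor is an assumption
        scale = q_max / abs_max
        x_q.append([max(q_min, min(q_max, round(v * scale))) for v in row])
        scales.append(scale)
    return x_q, scales


x = [[0.5, -1.0, 0.25], [2.0, 4.0, -8.0]]
x_q, scales = quantize_activations(x)
print(x_q)     # [[64, -127, 32], [32, 64, -127]]
print(scales)  # [127.0, 15.875]
```

Dividing a quantized row by its scale gives back an approximation of the original activations, for example `64 / 127` for the original `0.5`.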
To learn more about how we trained and fine-tuned BitNet models, check out the [blog post](https://huggingface.co/blog/1_58_llm_extreme_quantization).
## Load a BitNet Model from the Hub
BitNet models can't be quantized on the fly; they need to be pre-trained or fine-tuned with the quantization applied (it is a quantization-aware training technique). Once trained, these models are already quantized and available as packed versions on the Hub.
A quantized model can be loaded as follows:
```py
from transformers import AutoModelForCausalLM
path = "/path/to/model"
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
```
## Pre-training / Fine-tuning a BitNet Model
If you're looking to pre-train or fine-tune your own 1.58-bit model using Nanotron, check out this [PR](https://github.com/huggingface/nanotron/pull/180). Everything you need to get started is there.
For fine-tuning, you'll need to convert the model from Hugging Face format to Nanotron format (which has some differences). You can find the conversion steps in this [PR](https://github.com/huggingface/nanotron/pull/174).
## Kernels
In our initial version, we chose to use `@torch.compile` to unpack the weights and perform the forward pass. It's very straightforward to implement and delivers significant speed improvements. We plan to integrate additional optimized kernels in future versions.
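To make the packed layout concrete, here is a plain-Python sketch of storing four ternary values in one byte, two bits each, with consecutive blocks of rows mapped to successive bit pairs. The helper names `pack_rows` and `unpack_rows` are hypothetical, and the sketch assumes the number of rows is a multiple of four. It mirrors the idea of the integration's `pack_weights` and `unpack_weights` but is not the library code.

```python
VALUES_PER_ITEM = 4  # four 2-bit values fit in one byte


def pack_rows(ternary):
    """Pack rows with values in {-1, 0, 1} into bytes (row count must be a multiple of 4)."""
    n_rows, n_cols = len(ternary), len(ternary[0])
    assert n_rows % VALUES_PER_ITEM == 0
    packed_rows = n_rows // VALUES_PER_ITEM
    packed = [[0] * n_cols for _ in range(packed_rows)]
    for i in range(VALUES_PER_ITEM):
        for r in range(packed_rows):
            for c in range(n_cols):
                value = ternary[i * packed_rows + r][c] + 1  # shift to {0, 1, 2}
                packed[r][c] |= value << (2 * i)
    return packed


def unpack_rows(packed):
    """Inverse of pack_rows: recover the ternary rows from the packed bytes."""
    packed_rows, n_cols = len(packed), len(packed[0])
    unpacked = [[0] * n_cols for _ in range(packed_rows * VALUES_PER_ITEM)]
    for i in range(VALUES_PER_ITEM):
        for r in range(packed_rows):
            for c in range(n_cols):
                bits = (packed[r][c] >> (2 * i)) & 0b11
                unpacked[i * packed_rows + r][c] = bits - 1  # back to {-1, 0, 1}
    return unpacked


packed = [[0b10100001, 0b00011000], [0b10010000, 0b00001010]]
unpacked = unpack_rows(packed)
print(unpacked)
print(pack_rows(unpacked) == packed)  # True
```

Unpacking is only a shift, a mask, and a subtraction per value, which is why it lends itself to compilation into a fused kernel.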

View File

@@ -968,6 +968,7 @@ _import_structure = {
"utils.quantization_config": [
"AqlmConfig",
"AwqConfig",
"BitNetConfig",
"BitsAndBytesConfig",
"CompressedTensorsConfig",
"EetqConfig",
@@ -5869,6 +5870,7 @@ if TYPE_CHECKING:
from .utils.quantization_config import (
AqlmConfig,
AwqConfig,
BitNetConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,

View File

@@ -25,6 +25,12 @@ _import_structure = {
"replace_quantization_scales",
"replace_with_awq_linear",
],
"bitnet": [
"BitLinear",
"pack_weights",
"replace_with_bitnet_linear",
"unpack_weights",
],
"bitsandbytes": [
"dequantize_and_replace",
"get_keys_to_not_convert",
@@ -120,6 +126,12 @@ if TYPE_CHECKING:
replace_quantization_scales,
replace_with_awq_linear,
)
from .bitnet import (
BitLinear,
pack_weights,
replace_with_bitnet_linear,
unpack_weights,
)
from .bitsandbytes import (
dequantize_and_replace,
get_keys_to_not_convert,

View File

@@ -0,0 +1,286 @@
from ..utils import is_accelerate_available, is_torch_available, logging
if is_accelerate_available():
from accelerate import init_empty_weights
if is_torch_available():
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.get_logger(__name__)
# the weights are ternary so can be represented with 2 bits, and they are packed in uint8 tensors, hence the number of values per item is 4
VALUES_PER_ITEM = 4
def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor of quantized weights into a compact format using 2 bits per value.
Parameters:
-----------
quantized_weights : torch.Tensor
A tensor containing ternary quantized weights with values in {-1, 0, 1}. These values are adjusted to
{0, 1, 2} before being packed.
Returns:
--------
torch.Tensor
A packed tensor where each element stores 4 quantized values (each using 2 bits) in an 8-bit format.
"""
original_shape = quantized_weights.shape
row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
if len(original_shape) == 1:
packed_tensor_shape = (row_dim,)
else:
packed_tensor_shape = (row_dim, *original_shape[1:])
quantized_weights += 1
packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
unpacked = quantized_weights.to(torch.uint8)
it = min(VALUES_PER_ITEM, (original_shape[0] // row_dim) + 1)
for i in range(it):
start = i * row_dim
end = min(start + row_dim, original_shape[0])
packed[: (end - start)] |= unpacked[start:end] << 2 * i
return packed
@torch.compile
def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.
Parameters:
-----------
packed : torch.Tensor
A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
dtype : torch.dtype
The dtype of the returned Tensor
Returns:
--------
torch.Tensor
A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.
Example:
--------
packed = torch.tensor([[0b10100001, 0b00011000],
[0b10010000, 0b00001010]], dtype=torch.uint8)
# Unpack the values
unpacked = unpack_weights(packed, dtype=torch.int64)
# Resulting unpacked tensor
print(unpacked)
# Output: tensor([[ 0, -1],
[-1, 1],
[-1, 1],
[-1, 1],
[ 1, 0],
[ 0, -1],
[ 1, -1],
[ 1, -1]])
Explanation of the example:
---------------------------
Let's take the first value, 0b10100001, as an example. We will only focus on the first column,
because every element is unpacked along the first dimension. Reading from the least significant bits:
- First 2 bits: `01` → 0 at [0][0]
- Second 2 bits: `00` → -1 at [2][0]
- Third 2 bits: `10` → 1 at [4][0]
- Fourth 2 bits: `10` → 1 at [6][0]
The second value of the same column (0b10010000) gives the values for [1][0], [3][0], [5][0], [7][0]
We subtract 1 because during the packing process, it's easier to work with values like 0, 1, and 2. To make this possible,
we add 1 to the original ternary weights (which are typically -1, 0, and 1) when packing them. When unpacking, we reverse
this by subtracting 1 to restore the original ternary values.
"""
packed_shape = packed.shape
if len(packed_shape) == 1:
original_row_dim = packed_shape[0] * VALUES_PER_ITEM
unpacked_shape = (original_row_dim,)
else:
original_row_dim = packed_shape[0] * VALUES_PER_ITEM
unpacked_shape = (original_row_dim, *packed_shape[1:])
unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
for i in range(VALUES_PER_ITEM):
start = i * packed_shape[0]
end = start + packed_shape[0]
mask = 3 << (2 * i)
unpacked[start:end] = (packed & mask) >> (2 * i)
return unpacked.to(dtype) - 1
class BitLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
super().__init__()
self.dtype = dtype
self.register_buffer(
"weight",
torch.zeros(
(out_features // VALUES_PER_ITEM, in_features),
dtype=torch.uint8,
device=device,
),
)
self.register_buffer(
"weight_scale",
torch.ones(
(1),
dtype=dtype,
device=device,
),
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
else:
self.bias = None
@torch.compile
def activation_quant(self, input, num_bits=8):
"""
Activation function : Performs symmetric, per-token quantization on the input activations.
Parameters:
-----------
input : torch.Tensor
Input activations to be quantized.
num_bits : int, optional (default=8)
Number of bits to use for quantization, determining the quantization range.
Returns:
--------
result : torch.Tensor
Quantized activation tensor, with values mapped to an `int8` range.
scale : torch.Tensor
The per-token scaling factors used to quantize the tensor.
"""
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (input * scale).round().clamp(Qn, Qp)
return result.to(torch.int8), scale
@torch.compile
def post_quant_process(self, input, input_scale, weight_scale):
out = input / (input_scale * weight_scale)
return out
def forward(self, input):
w = self.weight
w_quant = unpack_weights(w, dtype=self.dtype)
input_quant, input_scale = self.activation_quant(input)
y = F.linear(input_quant.to(self.dtype), w_quant)
y = self.post_quant_process(y, self.weight_scale, input_scale)
if self.bias is not None:
y += self.bias.view(1, -1).expand_as(y)
return y
def _replace_with_bitnet_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
if current_key_name is None:
current_key_name = []
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
in_features = module.in_features
out_features = module.out_features
model._modules[name] = BitLinear(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
)
has_been_replaced = True
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bitnet_linear(
module,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_bitnet_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
pre_quantized=False,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `BitLinear` modules.
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
`disk`).
"""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config and quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_bitnet_linear(
model,
modules_to_not_convert,
current_key_name,
quantization_config,
pre_quantized=pre_quantized,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using bitnet but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
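The arithmetic of `BitLinear.forward` can be followed on a tiny example. The sketch below is a plain-Python approximation under stated assumptions: the weights are already unpacked to ternary values, `weight_scale` is the reciprocal of the mean absolute weight, and the function name `bitlinear_forward` is hypothetical. It quantizes each input row, accumulates integer products, and divides by the product of the two scales.

```python
def bitlinear_forward(x, w_ternary, weight_scale, bias=None, num_bits=8, eps=1e-5):
    """Quantized linear layer on nested lists (illustrative only).

    x: rows of floats; w_ternary: out_features rows of in_features ternary values.
    """
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1
    outputs = []
    for row in x:
        # Per-row absmax scale for the activations.
        input_scale = q_max / max(max(abs(v) for v in row), eps)
        row_q = [max(q_min, min(q_max, round(v * input_scale))) for v in row]
        out = []
        for j, w_row in enumerate(w_ternary):
            acc = sum(a * w for a, w in zip(row_q, w_row))  # integer accumulation
            y = acc / (input_scale * weight_scale)          # undo both scales
            if bias is not None:
                y += bias[j]
            out.append(y)
        outputs.append(out)
    return outputs


x = [[1.0, -2.0, 0.5, 4.0]]
w_ternary = [[1, 0, -1, 1], [0, -1, 1, 0]]
y = bitlinear_forward(x, w_ternary, weight_scale=2.0)
print(y)  # close to the full-precision result [[2.25, 1.25]]
```

The small gap to the full-precision result comes from rounding the activations to 8 bits.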

View File

@@ -18,6 +18,7 @@ from ..models.auto.configuration_auto import AutoConfig
from ..utils.quantization_config import (
AqlmConfig,
AwqConfig,
BitNetConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
@@ -31,6 +32,7 @@ from ..utils.quantization_config import (
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
@@ -54,6 +56,7 @@ AUTO_QUANTIZER_MAPPING = {
"compressed-tensors": CompressedTensorsHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -68,6 +71,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"compressed-tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
}

View File

@@ -0,0 +1,115 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Union
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class BitNetHfQuantizer(HfQuantizer):
"""
1.58-bit quantization from BitNet quantization method:
Before loading, it converts the linear layers into BitLinear layers.
Check out the paper introducing this method: https://arxiv.org/pdf/2402.17764
"""
requires_parameters_quantization = False
requires_calibration = True
required_packages = ["accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
raise ImportError("Loading a BitNet quantized model requires accelerate (`pip install accelerate`)")
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Loading ternary weights from tf/flax is currently not supported, please make"
" sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
logger.warning_once(
"You don't have a GPU available to load the model, the inference will be slow because of weight unpacking"
)
return
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded a BitNet model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
elif device_map is not None:
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load a BitNet model with a device_map that contains a CPU or disk device. "
"This is not supported. Please remove the CPU or disk device from the device_map."
)
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
self.modules_to_not_convert = get_keys_to_not_convert(model)
if self.quantization_config.modules_to_not_convert is not None:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
model = replace_with_bitnet_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
target_dtype = torch.int8
return target_dtype
def is_serializable(self, safe_serialization=None):
return True
@property
def is_trainable(self) -> bool:
return False
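The `device_map` check in `validate_environment` can be tried in isolation. The sketch below is a standalone approximation: `check_device_map` is a hypothetical helper that is not part of transformers, and it only handles plain strings and integers as device values.

```python
def check_device_map(device_map):
    """Raise if a dict-style device_map places any module on CPU or disk."""
    if isinstance(device_map, dict):
        targets = set(device_map.values())
        if "cpu" in targets or "disk" in targets:
            raise ValueError(
                "BitNet models cannot be loaded with a device_map that contains a CPU or disk device."
            )
    return device_map


# String shortcuts and all-GPU maps pass through unchanged.
check_device_map("auto")
check_device_map({"model.embed_tokens": 0, "lm_head": 0})

# A map that offloads a layer to CPU is rejected.
try:
    check_device_map({"model.layers.0": 0, "model.layers.1": "cpu"})
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```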

View File

@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
COMPRESSED_TENSORS = "compressed-tensors"
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"
BITNET = "bitnet"
class AWQLinearVersion(str, Enum):
@@ -1308,4 +1309,22 @@ class TorchAoConfig(QuantizationConfigMixin):
return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
def __repr__(self):
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.quant_type_kwargs.items())})"
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
@dataclass
class BitNetConfig(QuantizationConfigMixin):
def __init__(
self,
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.BITNET
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct
"""
pass
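`modules_to_not_convert` is the only option `BitNetConfig` exposes, and the replacement helper matches its entries as substrings of the dotted module path. The sketch below isolates that test; the function name `should_convert` is hypothetical and the module paths are made-up examples.

```python
def should_convert(module_path, modules_to_not_convert):
    """Return True when no excluded key occurs as a substring of the dotted module path."""
    return not any(key in module_path for key in modules_to_not_convert)


print(should_convert("model.layers.0.self_attn.q_proj", ["lm_head"]))  # True
print(should_convert("lm_head", ["lm_head"]))                          # False
print(should_convert("model.layers.0.mlp.down_proj", ["mlp"]))         # False
```

Because the match is a substring test, a short key such as `"mlp"` excludes every module whose path contains it, so exclusion keys are best kept specific.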

View File

@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitNetConfig,
OPTForCausalLM,
)
from transformers.testing_utils import (
require_accelerate,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class BitNetConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
"""
quantization_config = BitNetConfig()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
@slow
@require_torch_gpu
@require_accelerate
class BitNetTest(unittest.TestCase):
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
device = "cuda"
# called only once for all tests in this class
@classmethod
def setUpClass(cls):
"""
Load the model
"""
cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_replace_with_bitlinear(self):
from transformers.integrations import BitLinear, replace_with_bitnet_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id)
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model = replace_with_bitnet_linear(model)
nb_bitnet_linear = 0
for module in model.modules():
if isinstance(module, BitLinear):
nb_bitnet_linear += 1
self.assertEqual(nb_linears - 1, nb_bitnet_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_text = "What are we having for dinner?"
expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
# Use the model and tokenizer loaded once in setUpClass
input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda")
output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
def test_packing_unpacking(self):
"""
Simple test of the packing and unpacking logic
"""
from transformers.integrations import pack_weights, unpack_weights
u = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8)
unpacked_u = unpack_weights(u, dtype=torch.bfloat16)
self.assertTrue(torch.equal(pack_weights(unpacked_u), u))
def test_activation_quant(self):
"""
test the activation function behaviour
"""
from transformers.integrations import BitLinear
layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
layer.to(self.device)
input_tensor = torch.tensor([[1.0, -1.0, -1.0, 1.0], [1.0, -1.0, 1.0, 1.0]], dtype=torch.float32).to(
torch_device
)
# Quantize the input tensor
quantized_tensor, scale = layer.activation_quant(input_tensor)
# Verify the output quantized tensor
self.assertEqual(quantized_tensor, input_tensor)
# Verify the scale tensor
self.assertEqual(scale, 127)
def test_weights_dtype(self):
"""
test the weights dtype after loading
"""
self_attn_q = self.quantized_model.model.layers[0].self_attn.q_proj.weight
self_attn_k = self.quantized_model.model.layers[0].self_attn.k_proj.weight
self_attn_v = self.quantized_model.model.layers[0].self_attn.v_proj.weight
self_attn_o = self.quantized_model.model.layers[0].self_attn.o_proj.weight
mlp_gate = self.quantized_model.model.layers[0].mlp.gate_proj.weight
mlp_up = self.quantized_model.model.layers[0].mlp.up_proj.weight
mlp_down = self.quantized_model.model.layers[0].mlp.down_proj.weight
self.assertEqual(self_attn_q.dtype, torch.uint8)
self.assertEqual(self_attn_k.dtype, torch.uint8)
self.assertEqual(self_attn_v.dtype, torch.uint8)
self.assertEqual(self_attn_o.dtype, torch.uint8)
self.assertEqual(mlp_up.dtype, torch.uint8)
self.assertEqual(mlp_gate.dtype, torch.uint8)
self.assertEqual(mlp_down.dtype, torch.uint8)
def test_replace_with_bitlinear_shape(self):
"""
test that the BitNet layer weight shapes are correct, and the weight_scale is correctly initialized to 1
"""
from transformers.integrations import replace_with_bitnet_linear
out_features = 1024
in_features = 512
class SimpleLinearModule(torch.nn.Module):
"""
Simple class to test BitLinear
"""
def __init__(
self,
in_features: int = in_features,
out_features: int = out_features,
bias: bool = False,
):
super().__init__()
self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
def forward(self, x):
return self.linear(x)
model = SimpleLinearModule()
replace_with_bitnet_linear(model)
self.assertEqual(list(model.linear.weight.shape), [out_features // 4, in_features])
self.assertEqual(model.linear.weight_scale, 1)
@slow
@require_torch_gpu
@require_accelerate
class BitNetSerializationTest(unittest.TestCase):
def test_model_serialization(self):
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
device = "cuda"
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device)
with torch.no_grad():
logits_ref = quantized_model.forward(input_tensor).logits
# Save
saved_model_id = "quant_model"
quantized_model.save_pretrained(saved_model_id)
# Remove old model
del quantized_model
torch.cuda.empty_cache()
# Load and check if the logits match
model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=device)
with torch.no_grad():
logits_loaded = model_loaded.forward(input_tensor).logits
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)