[bnb] Let's make serialization of 4bit models possible (#26037)
* updated bitsandbytes.py * rm test_raise_* from test_4bit.py * add test_4bit_serialization.py * modeling_utils bulk edits * bnb_ver 0.41.3 in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * @slow reinstated Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * bnb ver 0.41.3 in src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * rm bnb version todo in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * moved 4b serialization tests to test_4bit * tests upd for opt * to torch_device Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * ruff fixes to tests * rm redundant bnb version check in mod_utils Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded modeling_utils.py::2188 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded test in modeling_utils.py::2199 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed NOT getattr(self, "is_8bit_serializable") Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * setting model.is_4bit_serializable * rm separate fp16_statistics arg from set_module... * rm else branch in integrations::bnb::set_module * bnb 4bit dtype check * upd comment on 4bit weights * upd tests for FP4 safe --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,7 @@ if is_accelerate_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
|
||||
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
|
||||
"""
|
||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
|
||||
@@ -37,8 +37,8 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
||||
The device on which to set the tensor.
|
||||
value (`torch.Tensor`, *optional*):
|
||||
The value of the tensor (useful when going from the meta device to any other device).
|
||||
fp16_statistics (`torch.HalfTensor`, *optional*):
|
||||
The list of fp16 statistics to set on the module, used for serialization.
|
||||
quantized_stats (`dict[str, Any]`, *optional*):
|
||||
Dict with items for either 4-bit or 8-bit serialization
|
||||
"""
|
||||
# Recurse if needed
|
||||
if "." in tensor_name:
|
||||
@@ -58,8 +58,7 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
||||
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
||||
|
||||
is_4bit = False
|
||||
is_8bit = False
|
||||
prequantized_loading = quantized_stats is not None
|
||||
if is_buffer or not is_bitsandbytes_available():
|
||||
is_8bit = False
|
||||
is_4bit = False
|
||||
@@ -74,32 +73,53 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
||||
new_value = old_value.to(device)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to("cpu")
|
||||
if value.dtype == torch.int8:
|
||||
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
|
||||
"0.37.2"
|
||||
)
|
||||
if not is_8bit_serializable:
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
else:
|
||||
new_value = torch.tensor(value, device="cpu")
|
||||
|
||||
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
|
||||
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
||||
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
|
||||
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
|
||||
new_value = new_value.T
|
||||
|
||||
kwargs = old_value.__dict__
|
||||
if is_8bit:
|
||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
||||
elif is_4bit:
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
|
||||
|
||||
if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
|
||||
raise ValueError(
|
||||
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
|
||||
)
|
||||
|
||||
if is_8bit:
|
||||
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
|
||||
"0.37.2"
|
||||
)
|
||||
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
||||
if prequantized_loading:
|
||||
setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
|
||||
elif is_4bit:
|
||||
if prequantized_loading:
|
||||
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
|
||||
"0.41.3"
|
||||
)
|
||||
if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
|
||||
raise ValueError(
|
||||
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=new_value,
|
||||
quantized_stats=quantized_stats,
|
||||
requires_grad=False,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
|
||||
module._parameters[tensor_name] = new_value
|
||||
if fp16_statistics is not None:
|
||||
setattr(module.weight, "SCB", fp16_statistics.to(device))
|
||||
|
||||
else:
|
||||
if value is None:
|
||||
@@ -117,7 +137,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
||||
|
||||
|
||||
def _replace_with_bnb_linear(
|
||||
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
@@ -675,6 +675,7 @@ def _load_state_dict_into_meta_model(
|
||||
is_quantized=False,
|
||||
is_safetensors=False,
|
||||
keep_in_fp32_modules=None,
|
||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
@@ -776,17 +777,41 @@ def _load_state_dict_into_meta_model(
|
||||
elif not is_quantized:
|
||||
# For backward compatibility with older versions of `accelerate`
|
||||
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
|
||||
else:
|
||||
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
|
||||
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
|
||||
else:
|
||||
fp16_statistics = None
|
||||
elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
|
||||
# handling newly quantized weights and loaded quantized weights
|
||||
# edit the param.dtype restrictions and is_quantized condition when adding new quant methods
|
||||
quantized_stats = {}
|
||||
|
||||
if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
|
||||
param_name + ".quant_state.bitsandbytes__nf4" in state_dict
|
||||
):
|
||||
# 4bit loading. Collecting components for restoring quantized weight
|
||||
# This can be expanded to make a universal call for any quantized weight loading
|
||||
for k, v in state_dict.items():
|
||||
if param_name + "." in k:
|
||||
quantized_stats[k] = v
|
||||
unexpected_keys.remove(k)
|
||||
|
||||
if "SCB" not in param_name:
|
||||
set_module_quantized_tensor_to_device(
|
||||
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
|
||||
model, param_name, param_device, value=param, quantized_stats=quantized_stats
|
||||
)
|
||||
|
||||
elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
|
||||
# 8bit loading. Could be combined with the above 4bit call.
|
||||
# condition looks unreliable
|
||||
fp16_statistics_key = param_name.replace("weight", "SCB")
|
||||
unexpected_keys.remove(fp16_statistics_key)
|
||||
set_module_quantized_tensor_to_device(
|
||||
model,
|
||||
param_name,
|
||||
param_device,
|
||||
value=param,
|
||||
quantized_stats={"SCB": state_dict[fp16_statistics_key]},
|
||||
)
|
||||
else:
|
||||
# loading not quantized params in quantized model
|
||||
set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
||||
@@ -2197,15 +2222,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
and not getattr(self, "is_8bit_serializable", False)
|
||||
and not _hf_peft_config_loaded
|
||||
):
|
||||
raise ValueError(
|
||||
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
||||
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
|
||||
raise NotImplementedError(
|
||||
"You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
|
||||
"If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
|
||||
)
|
||||
|
||||
# If the model has adapters attached, you can save the adapters
|
||||
if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
|
||||
if (
|
||||
getattr(self, "is_loaded_in_4bit", False)
|
||||
and not getattr(self, "is_4bit_serializable", False)
|
||||
and not _hf_peft_config_loaded
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
|
||||
"You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
|
||||
"If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
|
||||
)
|
||||
|
||||
if getattr(self, "_awq_is_fused", False):
|
||||
@@ -2774,8 +2803,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
use_safetensors = False
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
|
||||
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
|
||||
else:
|
||||
is_4bit_serializable = False
|
||||
is_8bit_serializable = False
|
||||
|
||||
if trust_remote_code is True:
|
||||
@@ -3064,10 +3095,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
|
||||
if (
|
||||
is_8bit_serializable
|
||||
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
|
||||
and load_in_8bit
|
||||
if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
|
||||
(is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
|
||||
):
|
||||
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
|
||||
logger.warning(
|
||||
@@ -3077,8 +3106,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
config.quantization_config = quantization_config
|
||||
elif (
|
||||
is_8bit_serializable
|
||||
and not load_in_8bit
|
||||
(is_8bit_serializable or is_4bit_serializable)
|
||||
and not (load_in_8bit or load_in_4bit)
|
||||
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
||||
):
|
||||
quantization_config = config.quantization_config
|
||||
@@ -3093,8 +3122,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
load_in_8bit = quantization_config.load_in_8bit
|
||||
load_in_4bit = quantization_config.load_in_4bit
|
||||
|
||||
if load_in_8bit:
|
||||
if load_in_8bit or load_in_4bit:
|
||||
if torch_dtype is None:
|
||||
torch_dtype = torch.float16
|
||||
if device_map is None:
|
||||
@@ -3112,12 +3142,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
elif (
|
||||
not is_8bit_serializable
|
||||
and not load_in_8bit
|
||||
and not (load_in_8bit or load_in_4bit)
|
||||
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
||||
):
|
||||
logger.warning(
|
||||
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
|
||||
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
|
||||
" `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
|
||||
" `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
|
||||
@@ -3525,6 +3555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
config.quantization_config = quantization_config
|
||||
model.is_8bit_serializable = is_8bit_serializable
|
||||
model.is_4bit_serializable = is_4bit_serializable
|
||||
|
||||
if load_in_8bit and torch_dtype is None:
|
||||
logger.warning(
|
||||
@@ -4018,10 +4049,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
if (
|
||||
state_dict[checkpoint_key].shape[-1] == 1
|
||||
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
|
||||
):
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
|
||||
pass
|
||||
else:
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
if resolved_archive_file is not None:
|
||||
@@ -4130,6 +4169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
is_quantized=is_quantized,
|
||||
is_safetensors=is_safetensors,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
error_msgs += new_error_msgs
|
||||
else:
|
||||
@@ -4167,10 +4207,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||
|
||||
if is_quantized:
|
||||
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
|
||||
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
archs = [] if model.config.architectures is None else model.config.architectures
|
||||
warner = logger.warning if model.__class__.__name__ in archs else logger.info
|
||||
|
||||
@@ -20,6 +20,7 @@ import unittest
|
||||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -29,6 +30,7 @@ from transformers import (
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
@@ -36,13 +38,21 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def get_some_linear_layer(model):
|
||||
if model.config.model_type == "gpt2":
|
||||
return model.transformer.h[0].mlp.c_fc
|
||||
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||
elif model.config.model_type == "opt":
|
||||
try:
|
||||
return model.decoder.layers[0].fc1
|
||||
except AttributeError:
|
||||
# for AutoModelforCausalLM
|
||||
return model.model.decoder.layers[0].fc1
|
||||
else:
|
||||
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -68,6 +78,10 @@ if is_torch_available():
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_raise_on_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_4bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_raise_if_config_and_load_in_4bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_4bit` raises an error
|
||||
"""
|
||||
bnb_config = BitsAndBytesConfig()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
quantization_config=bnb_config,
|
||||
load_in_4bit=True,
|
||||
device_map="auto",
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
@@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase):
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
# test with `t5-small`
|
||||
@@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest):
|
||||
class Bnb4BitGPT2Test(Bnb4BitTest):
|
||||
model_name = "gpt2-xl"
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class BaseSerializationTest(unittest.TestCase):
|
||||
model_name = "facebook/opt-125m"
|
||||
input_text = "Mars colonists' favorite meals are"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default.
|
||||
See ExtendedSerializationTest class for more params combinations.
|
||||
"""
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
self.quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type=quant_type,
|
||||
bnb_4bit_use_double_quant=double_quant,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
model_0 = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
quantization_config=self.quantization_config,
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
|
||||
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
|
||||
|
||||
# checking quantized linear module weight
|
||||
linear = get_some_linear_layer(model_1)
|
||||
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
|
||||
self.assertTrue(hasattr(linear.weight, "quant_state"))
|
||||
self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
|
||||
|
||||
# checking memory footpring
|
||||
self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
|
||||
|
||||
# Matching all parameters and their quant_state items:
|
||||
d0 = dict(model_0.named_parameters())
|
||||
d1 = dict(model_1.named_parameters())
|
||||
self.assertTrue(d0.keys() == d1.keys())
|
||||
|
||||
for k in d0.keys():
|
||||
self.assertTrue(d0[k].shape == d1[k].shape)
|
||||
self.assertTrue(d0[k].device.type == d1[k].device.type)
|
||||
self.assertTrue(d0[k].device == d1[k].device)
|
||||
self.assertTrue(d0[k].dtype == d1[k].dtype)
|
||||
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
|
||||
|
||||
if isinstance(d0[k], bnb.nn.modules.Params4bit):
|
||||
for v0, v1 in zip(
|
||||
d0[k].quant_state.as_dict().values(),
|
||||
d1[k].quant_state.as_dict().values(),
|
||||
):
|
||||
if isinstance(v0, torch.Tensor):
|
||||
self.assertTrue(torch.equal(v0, v1.to(v0.device)))
|
||||
else:
|
||||
self.assertTrue(v0 == v1)
|
||||
|
||||
# comparing forward() outputs
|
||||
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
out_0 = model_0(**encoded_input)
|
||||
out_1 = model_1(**encoded_input)
|
||||
self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
|
||||
|
||||
# comparing generate() outputs
|
||||
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
|
||||
output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)
|
||||
|
||||
def _decode(token):
|
||||
return tokenizer.decode(token, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(
|
||||
[_decode(x) for x in output_sequences_0],
|
||||
[_decode(x) for x in output_sequences_1],
|
||||
)
|
||||
|
||||
|
||||
class ExtendedSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
tests more combinations of parameters
|
||||
"""
|
||||
|
||||
def test_nf4_single_unsafe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
|
||||
|
||||
def test_nf4_single_safe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
|
||||
|
||||
def test_nf4_double_unsafe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
|
||||
|
||||
# nf4 double safetensors quantization is tested in test_serialization() method from the parent class
|
||||
|
||||
def test_fp4_single_unsafe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
|
||||
|
||||
def test_fp4_single_safe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
|
||||
|
||||
def test_fp4_double_unsafe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
|
||||
|
||||
def test_fp4_double_safe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
|
||||
|
||||
|
||||
class BloomSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
default BaseSerializationTest config tested with Bloom family model
|
||||
"""
|
||||
|
||||
model_name = "bigscience/bloom-560m"
|
||||
|
||||
|
||||
class GPTSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
default BaseSerializationTest config tested with GPT family model
|
||||
"""
|
||||
|
||||
model_name = "gpt2-xl"
|
||||
|
||||
Reference in New Issue
Block a user