[bnb] Let's make serialization of 4bit models possible (#26037)
* updated bitsandbytes.py * rm test_raise_* from test_4bit.py * add test_4bit_serialization.py * modeling_utils bulk edits * bnb_ver 0.41.3 in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * @slow reinstated Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * bnb ver 0.41.3 in src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * rm bnb version todo in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * moved 4b serialization tests to test_4bit * tests upd for opt * to torch_device Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * ruff fixes to tests * rm redundant bnb version check in mod_utils Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded modeling_utils.py::2188 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded test in modeling_utils.py::2199 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed NOT getattr(self, "is_8bit_serializable") Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * setting model.is_4bit_serializable * rm separate fp16_statistics arg from set_module... * rm else branch in integrations::bnb::set_module * bnb 4bit dtype check * upd comment on 4bit weights * upd tests for FP4 safe --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,7 @@ if is_accelerate_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
|
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
|
||||||
"""
|
"""
|
||||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||||
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
|
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
|
||||||
@@ -37,8 +37,8 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
|||||||
The device on which to set the tensor.
|
The device on which to set the tensor.
|
||||||
value (`torch.Tensor`, *optional*):
|
value (`torch.Tensor`, *optional*):
|
||||||
The value of the tensor (useful when going from the meta device to any other device).
|
The value of the tensor (useful when going from the meta device to any other device).
|
||||||
fp16_statistics (`torch.HalfTensor`, *optional*):
|
quantized_stats (`dict[str, Any]`, *optional*):
|
||||||
The list of fp16 statistics to set on the module, used for serialization.
|
Dict with items for either 4-bit or 8-bit serialization
|
||||||
"""
|
"""
|
||||||
# Recurse if needed
|
# Recurse if needed
|
||||||
if "." in tensor_name:
|
if "." in tensor_name:
|
||||||
@@ -58,8 +58,7 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
|||||||
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
||||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
||||||
|
|
||||||
is_4bit = False
|
prequantized_loading = quantized_stats is not None
|
||||||
is_8bit = False
|
|
||||||
if is_buffer or not is_bitsandbytes_available():
|
if is_buffer or not is_bitsandbytes_available():
|
||||||
is_8bit = False
|
is_8bit = False
|
||||||
is_4bit = False
|
is_4bit = False
|
||||||
@@ -74,32 +73,53 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
|||||||
new_value = old_value.to(device)
|
new_value = old_value.to(device)
|
||||||
elif isinstance(value, torch.Tensor):
|
elif isinstance(value, torch.Tensor):
|
||||||
new_value = value.to("cpu")
|
new_value = value.to("cpu")
|
||||||
if value.dtype == torch.int8:
|
|
||||||
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
|
|
||||||
"0.37.2"
|
|
||||||
)
|
|
||||||
if not is_8bit_serializable:
|
|
||||||
raise ValueError(
|
|
||||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
|
||||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
new_value = torch.tensor(value, device="cpu")
|
new_value = torch.tensor(value, device="cpu")
|
||||||
|
|
||||||
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
|
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
|
||||||
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
||||||
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
|
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
|
||||||
new_value = new_value.T
|
new_value = new_value.T
|
||||||
|
|
||||||
kwargs = old_value.__dict__
|
kwargs = old_value.__dict__
|
||||||
if is_8bit:
|
|
||||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
|
||||||
elif is_4bit:
|
|
||||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
|
|
||||||
|
|
||||||
|
if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_8bit:
|
||||||
|
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
|
||||||
|
"0.37.2"
|
||||||
|
)
|
||||||
|
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
|
||||||
|
raise ValueError(
|
||||||
|
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||||
|
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||||
|
)
|
||||||
|
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
||||||
|
if prequantized_loading:
|
||||||
|
setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
|
||||||
|
elif is_4bit:
|
||||||
|
if prequantized_loading:
|
||||||
|
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
|
||||||
|
"0.41.3"
|
||||||
|
)
|
||||||
|
if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
|
||||||
|
raise ValueError(
|
||||||
|
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
|
||||||
|
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||||
|
)
|
||||||
|
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||||
|
data=new_value,
|
||||||
|
quantized_stats=quantized_stats,
|
||||||
|
requires_grad=False,
|
||||||
|
device=device,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
|
||||||
module._parameters[tensor_name] = new_value
|
module._parameters[tensor_name] = new_value
|
||||||
if fp16_statistics is not None:
|
|
||||||
setattr(module.weight, "SCB", fp16_statistics.to(device))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -117,7 +137,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
|||||||
|
|
||||||
|
|
||||||
def _replace_with_bnb_linear(
|
def _replace_with_bnb_linear(
|
||||||
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
|
model,
|
||||||
|
modules_to_not_convert=None,
|
||||||
|
current_key_name=None,
|
||||||
|
quantization_config=None,
|
||||||
|
has_been_replaced=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Private method that wraps the recursion for module replacement.
|
Private method that wraps the recursion for module replacement.
|
||||||
|
|||||||
@@ -675,6 +675,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
is_quantized=False,
|
is_quantized=False,
|
||||||
is_safetensors=False,
|
is_safetensors=False,
|
||||||
keep_in_fp32_modules=None,
|
keep_in_fp32_modules=None,
|
||||||
|
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||||
@@ -776,17 +777,41 @@ def _load_state_dict_into_meta_model(
|
|||||||
elif not is_quantized:
|
elif not is_quantized:
|
||||||
# For backward compatibility with older versions of `accelerate`
|
# For backward compatibility with older versions of `accelerate`
|
||||||
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
|
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
|
||||||
else:
|
elif param.dtype in (torch.int8, torch.uint8) and is_quantized:
|
||||||
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
|
# handling newly quantized weights and loaded quantized weights
|
||||||
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
|
# edit the param.dtype restrictions and is_quantized condition when adding new quant methods
|
||||||
else:
|
quantized_stats = {}
|
||||||
fp16_statistics = None
|
|
||||||
|
if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or (
|
||||||
|
param_name + ".quant_state.bitsandbytes__nf4" in state_dict
|
||||||
|
):
|
||||||
|
# 4bit loading. Collecting components for restoring quantized weight
|
||||||
|
# This can be expanded to make a universal call for any quantized weight loading
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if param_name + "." in k:
|
||||||
|
quantized_stats[k] = v
|
||||||
|
unexpected_keys.remove(k)
|
||||||
|
|
||||||
if "SCB" not in param_name:
|
|
||||||
set_module_quantized_tensor_to_device(
|
set_module_quantized_tensor_to_device(
|
||||||
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
|
model, param_name, param_device, value=param, quantized_stats=quantized_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
|
||||||
|
# 8bit loading. Could be combined with the above 4bit call.
|
||||||
|
# condition looks unreliable
|
||||||
|
fp16_statistics_key = param_name.replace("weight", "SCB")
|
||||||
|
unexpected_keys.remove(fp16_statistics_key)
|
||||||
|
set_module_quantized_tensor_to_device(
|
||||||
|
model,
|
||||||
|
param_name,
|
||||||
|
param_device,
|
||||||
|
value=param,
|
||||||
|
quantized_stats={"SCB": state_dict[fp16_statistics_key]},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# loading not quantized params in quantized model
|
||||||
|
set_module_quantized_tensor_to_device(model, param_name, param_device, value=param)
|
||||||
|
|
||||||
return error_msgs, offload_index, state_dict_index
|
return error_msgs, offload_index, state_dict_index
|
||||||
|
|
||||||
|
|
||||||
@@ -2197,15 +2222,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
and not getattr(self, "is_8bit_serializable", False)
|
and not getattr(self, "is_8bit_serializable", False)
|
||||||
and not _hf_peft_config_loaded
|
and not _hf_peft_config_loaded
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise NotImplementedError(
|
||||||
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
"You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
|
||||||
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
|
"If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the model has adapters attached, you can save the adapters
|
if (
|
||||||
if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
|
getattr(self, "is_loaded_in_4bit", False)
|
||||||
|
and not getattr(self, "is_4bit_serializable", False)
|
||||||
|
and not _hf_peft_config_loaded
|
||||||
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
|
"You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
|
||||||
|
"If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(self, "_awq_is_fused", False):
|
if getattr(self, "_awq_is_fused", False):
|
||||||
@@ -2774,8 +2803,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
use_safetensors = False
|
use_safetensors = False
|
||||||
|
|
||||||
if is_bitsandbytes_available():
|
if is_bitsandbytes_available():
|
||||||
|
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
|
||||||
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
|
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")
|
||||||
else:
|
else:
|
||||||
|
is_4bit_serializable = False
|
||||||
is_8bit_serializable = False
|
is_8bit_serializable = False
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
@@ -3064,10 +3095,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if low_cpu_mem_usage is None:
|
if low_cpu_mem_usage is None:
|
||||||
low_cpu_mem_usage = True
|
low_cpu_mem_usage = True
|
||||||
|
|
||||||
if (
|
if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and (
|
||||||
is_8bit_serializable
|
(is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit)
|
||||||
and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES
|
|
||||||
and load_in_8bit
|
|
||||||
):
|
):
|
||||||
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
|
if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -3077,8 +3106,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
config.quantization_config = quantization_config
|
config.quantization_config = quantization_config
|
||||||
elif (
|
elif (
|
||||||
is_8bit_serializable
|
(is_8bit_serializable or is_4bit_serializable)
|
||||||
and not load_in_8bit
|
and not (load_in_8bit or load_in_4bit)
|
||||||
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
||||||
):
|
):
|
||||||
quantization_config = config.quantization_config
|
quantization_config = config.quantization_config
|
||||||
@@ -3093,8 +3122,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
|
|
||||||
load_in_8bit = quantization_config.load_in_8bit
|
load_in_8bit = quantization_config.load_in_8bit
|
||||||
|
load_in_4bit = quantization_config.load_in_4bit
|
||||||
|
|
||||||
if load_in_8bit:
|
if load_in_8bit or load_in_4bit:
|
||||||
if torch_dtype is None:
|
if torch_dtype is None:
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
@@ -3112,12 +3142,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
elif (
|
elif (
|
||||||
not is_8bit_serializable
|
not is_8bit_serializable
|
||||||
and not load_in_8bit
|
and not (load_in_8bit or load_in_4bit)
|
||||||
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
|
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
|
||||||
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
|
" `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with "
|
||||||
" `pip install --upgrade bitsandbytes`."
|
" `pip install --upgrade bitsandbytes`."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -3525,6 +3555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
config.quantization_config = quantization_config
|
config.quantization_config = quantization_config
|
||||||
model.is_8bit_serializable = is_8bit_serializable
|
model.is_8bit_serializable = is_8bit_serializable
|
||||||
|
model.is_4bit_serializable = is_4bit_serializable
|
||||||
|
|
||||||
if load_in_8bit and torch_dtype is None:
|
if load_in_8bit and torch_dtype is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -4018,6 +4049,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model_key in model_state_dict
|
model_key in model_state_dict
|
||||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||||
):
|
):
|
||||||
|
if (
|
||||||
|
state_dict[checkpoint_key].shape[-1] == 1
|
||||||
|
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
|
||||||
|
):
|
||||||
|
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||||
|
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
mismatched_keys.append(
|
mismatched_keys.append(
|
||||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||||
)
|
)
|
||||||
@@ -4130,6 +4169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
is_quantized=is_quantized,
|
is_quantized=is_quantized,
|
||||||
is_safetensors=is_safetensors,
|
is_safetensors=is_safetensors,
|
||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||||
|
unexpected_keys=unexpected_keys,
|
||||||
)
|
)
|
||||||
error_msgs += new_error_msgs
|
error_msgs += new_error_msgs
|
||||||
else:
|
else:
|
||||||
@@ -4167,10 +4207,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||||
|
|
||||||
if is_quantized:
|
|
||||||
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
|
|
||||||
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
|
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
archs = [] if model.config.architectures is None else model.config.architectures
|
archs = [] if model.config.architectures is None else model.config.architectures
|
||||||
warner = logger.warning if model.__class__.__name__ in archs else logger.info
|
warner = logger.warning if model.__class__.__name__ in archs else logger.info
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import unittest
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
@@ -29,6 +30,7 @@ from transformers import (
|
|||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
is_bitsandbytes_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -36,12 +38,20 @@ from transformers.testing_utils import (
|
|||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_some_linear_layer(model):
|
def get_some_linear_layer(model):
|
||||||
if model.config.model_type == "gpt2":
|
if model.config.model_type == "gpt2":
|
||||||
return model.transformer.h[0].mlp.c_fc
|
return model.transformer.h[0].mlp.c_fc
|
||||||
|
elif model.config.model_type == "opt":
|
||||||
|
try:
|
||||||
|
return model.decoder.layers[0].fc1
|
||||||
|
except AttributeError:
|
||||||
|
# for AutoModelforCausalLM
|
||||||
|
return model.model.decoder.layers[0].fc1
|
||||||
|
else:
|
||||||
return model.transformer.h[0].mlp.dense_4h_to_h
|
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +78,10 @@ if is_torch_available():
|
|||||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||||
|
|
||||||
|
|
||||||
|
if is_bitsandbytes_available():
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
|
|
||||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
def test_raise_on_save_pretrained(self):
|
|
||||||
r"""
|
|
||||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
self.model_4bit.save_pretrained(tmpdirname)
|
|
||||||
|
|
||||||
def test_raise_if_config_and_load_in_4bit(self):
|
|
||||||
r"""
|
|
||||||
Test that loading the model with the config and `load_in_4bit` raises an error
|
|
||||||
"""
|
|
||||||
bnb_config = BitsAndBytesConfig()
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
quantization_config=bnb_config,
|
|
||||||
load_in_4bit=True,
|
|
||||||
device_map="auto",
|
|
||||||
bnb_4bit_quant_type="nf4",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_device_and_dtype_assignment(self):
|
def test_device_and_dtype_assignment(self):
|
||||||
r"""
|
r"""
|
||||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||||
@@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase):
|
|||||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||||
both cases.
|
both cases.
|
||||||
"""
|
"""
|
||||||
import bitsandbytes as bnb
|
|
||||||
|
|
||||||
from transformers import T5ForConditionalGeneration
|
from transformers import T5ForConditionalGeneration
|
||||||
|
|
||||||
# test with `t5-small`
|
# test with `t5-small`
|
||||||
@@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest):
|
|||||||
class Bnb4BitGPT2Test(Bnb4BitTest):
|
class Bnb4BitGPT2Test(Bnb4BitTest):
|
||||||
model_name = "gpt2-xl"
|
model_name = "gpt2-xl"
|
||||||
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
||||||
|
|
||||||
|
|
||||||
|
@require_bitsandbytes
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
@slow
|
||||||
|
class BaseSerializationTest(unittest.TestCase):
|
||||||
|
model_name = "facebook/opt-125m"
|
||||||
|
input_text = "Mars colonists' favorite meals are"
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||||
|
r"""
|
||||||
|
Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default.
|
||||||
|
See ExtendedSerializationTest class for more params combinations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
self.quantization_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type=quant_type,
|
||||||
|
bnb_4bit_use_double_quant=double_quant,
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
model_0 = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
quantization_config=self.quantization_config,
|
||||||
|
device_map=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(tmpdirname)
|
||||||
|
self.assertTrue(hasattr(config, "quantization_config"))
|
||||||
|
|
||||||
|
model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
|
||||||
|
|
||||||
|
# checking quantized linear module weight
|
||||||
|
linear = get_some_linear_layer(model_1)
|
||||||
|
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
|
||||||
|
self.assertTrue(hasattr(linear.weight, "quant_state"))
|
||||||
|
self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
|
||||||
|
|
||||||
|
# checking memory footpring
|
||||||
|
self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
|
||||||
|
|
||||||
|
# Matching all parameters and their quant_state items:
|
||||||
|
d0 = dict(model_0.named_parameters())
|
||||||
|
d1 = dict(model_1.named_parameters())
|
||||||
|
self.assertTrue(d0.keys() == d1.keys())
|
||||||
|
|
||||||
|
for k in d0.keys():
|
||||||
|
self.assertTrue(d0[k].shape == d1[k].shape)
|
||||||
|
self.assertTrue(d0[k].device.type == d1[k].device.type)
|
||||||
|
self.assertTrue(d0[k].device == d1[k].device)
|
||||||
|
self.assertTrue(d0[k].dtype == d1[k].dtype)
|
||||||
|
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
|
||||||
|
|
||||||
|
if isinstance(d0[k], bnb.nn.modules.Params4bit):
|
||||||
|
for v0, v1 in zip(
|
||||||
|
d0[k].quant_state.as_dict().values(),
|
||||||
|
d1[k].quant_state.as_dict().values(),
|
||||||
|
):
|
||||||
|
if isinstance(v0, torch.Tensor):
|
||||||
|
self.assertTrue(torch.equal(v0, v1.to(v0.device)))
|
||||||
|
else:
|
||||||
|
self.assertTrue(v0 == v1)
|
||||||
|
|
||||||
|
# comparing forward() outputs
|
||||||
|
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
out_0 = model_0(**encoded_input)
|
||||||
|
out_1 = model_1(**encoded_input)
|
||||||
|
self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
|
||||||
|
|
||||||
|
# comparing generate() outputs
|
||||||
|
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
|
||||||
|
output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)
|
||||||
|
|
||||||
|
def _decode(token):
|
||||||
|
return tokenizer.decode(token, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
[_decode(x) for x in output_sequences_0],
|
||||||
|
[_decode(x) for x in output_sequences_1],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendedSerializationTest(BaseSerializationTest):
|
||||||
|
"""
|
||||||
|
tests more combinations of parameters
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_nf4_single_unsafe(self):
|
||||||
|
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
|
||||||
|
|
||||||
|
def test_nf4_single_safe(self):
|
||||||
|
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
|
||||||
|
|
||||||
|
def test_nf4_double_unsafe(self):
|
||||||
|
self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
|
||||||
|
|
||||||
|
# nf4 double safetensors quantization is tested in test_serialization() method from the parent class
|
||||||
|
|
||||||
|
def test_fp4_single_unsafe(self):
|
||||||
|
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
|
||||||
|
|
||||||
|
def test_fp4_single_safe(self):
|
||||||
|
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
|
||||||
|
|
||||||
|
def test_fp4_double_unsafe(self):
|
||||||
|
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
|
||||||
|
|
||||||
|
def test_fp4_double_safe(self):
|
||||||
|
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
class BloomSerializationTest(BaseSerializationTest):
|
||||||
|
"""
|
||||||
|
default BaseSerializationTest config tested with Bloom family model
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name = "bigscience/bloom-560m"
|
||||||
|
|
||||||
|
|
||||||
|
class GPTSerializationTest(BaseSerializationTest):
|
||||||
|
"""
|
||||||
|
default BaseSerializationTest config tested with GPT family model
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name = "gpt2-xl"
|
||||||
|
|||||||
Reference in New Issue
Block a user