Hqq serialization (#33141)
* HQQ model serialization attempt * fix hqq dispatch and unexpected keys * style * remove check_old_param * revert to check HQQLinear in quantizer_hqq.py * revert to check HQQLinear in quantizer_hqq.py * update HqqConfig default params * make ci happy * make ci happy * revert to HQQLinear check in quantizer_hqq.py * check hqq_min version 0.2.0 * set axis=1 as default in quantization_config.py * validate_env with hqq>=0.2.0 version message * deprecated hqq kwargs message * make ci happy * remove run_expected_keys_check hack + bump to 0.2.1 min hqq version * fix unexpected_keys hqq update * add pre_quantized check * add update_expected_keys to base quantizerr * ci base.py fix? * ci base.py fix? * fix "quantization typo" src/transformers/utils/quantization_config.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix post merge --------- Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
6
docs/source/en/quantization/hqq.md
Normal file → Executable file
6
docs/source/en/quantization/hqq.md
Normal file → Executable file
@@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
|
||||
|
||||
# Method 1: all linear layers will use the same quantization config
|
||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
|
||||
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||
```
|
||||
|
||||
``` Python
|
||||
# Method 2: each linear layer with the same tag will use a dedicated quantization config
|
||||
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
|
||||
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
|
||||
q4_config = {'nbits':4, 'group_size':64}
|
||||
q3_config = {'nbits':3, 'group_size':32}
|
||||
quant_config = HqqConfig(dynamic_config={
|
||||
'self_attn.q_proj':q4_config,
|
||||
'self_attn.k_proj':q4_config,
|
||||
|
||||
@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_
|
||||
|
||||
has_been_replaced = True
|
||||
|
||||
# Add these fake parameters to avoid loading fail
|
||||
for att in ["W_q", "meta"]:
|
||||
setattr(module, att, None)
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _prepare_for_hqq_linear(
|
||||
module,
|
||||
@@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
||||
|
||||
# Convert quantization_config to layer-wise config
|
||||
skip_modules = quantization_config.skip_modules
|
||||
quant_config = quantization_config.to_dict()
|
||||
quant_config = quantization_config.quant_config
|
||||
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
|
||||
|
||||
if any(key in linear_tags for key in quant_config.keys()):
|
||||
@@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
|
||||
)
|
||||
|
||||
# We store quantization config as linear_tag -> hqq quant config
|
||||
model.config.quantization_config = patch_params
|
||||
model.config.quantization_config = {
|
||||
"quant_config": quant_config,
|
||||
"quant_method": quantization_config.quant_method,
|
||||
"skip_modules": skip_modules,
|
||||
}
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning("No linear modules were found in your model for quantization.")
|
||||
|
||||
@@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
|
||||
old_param = model
|
||||
splits = param_name.split(".")
|
||||
for split in splits:
|
||||
old_param = getattr(old_param, split)
|
||||
# Not all the attributes of a module are Parameters/Tensor
|
||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||
old_param = None
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if old_param is not None:
|
||||
if dtype is None:
|
||||
param = param.to(old_param.dtype)
|
||||
@@ -3819,6 +3824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
from_pt = not (from_tf | from_flax)
|
||||
|
||||
# load pt weights early so that we know which dtype to init the model under
|
||||
|
||||
if from_pt:
|
||||
if not is_sharded and state_dict is None:
|
||||
# Time to load the checkpoint
|
||||
@@ -4176,6 +4182,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
prefix = model.base_model_prefix
|
||||
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
||||
|
||||
def _fix_key(key):
|
||||
if "beta" in key:
|
||||
return key.replace("beta", "bias")
|
||||
@@ -4290,7 +4299,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
value = torch.empty(*param.size(), dtype=target_dtype)
|
||||
if (
|
||||
not is_quantized
|
||||
or getattr(hf_quantizer, "requires_parameters_quantization", False)
|
||||
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
|
||||
or not hf_quantizer.check_quantized_param(
|
||||
model, param_value=value, param_name=key, state_dict={}
|
||||
)
|
||||
|
||||
12
src/transformers/quantizers/base.py
Normal file → Executable file
12
src/transformers/quantizers/base.py
Normal file → Executable file
@@ -109,6 +109,18 @@ class HfQuantizer(ABC):
|
||||
"""
|
||||
return missing_keys
|
||||
|
||||
def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
|
||||
"""
|
||||
Override this method if you want to adjust the `update_expected_keys`.
|
||||
|
||||
Args:
|
||||
expected_keys (`List[str]`, *optional*):
|
||||
The list of the expected keys in the initialized model.
|
||||
loaded_keys (`List[str]`, *optional*):
|
||||
The list of the loaded keys in the checkpoint.
|
||||
"""
|
||||
return expected_keys
|
||||
|
||||
def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
|
||||
"""
|
||||
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
|
||||
|
||||
@@ -62,7 +62,7 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not (is_hqq_available()):
|
||||
raise ImportError(
|
||||
"HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
|
||||
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
|
||||
)
|
||||
|
||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||
@@ -91,6 +91,65 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
else:
|
||||
self.using_multi_gpu = len(set(device_map.values())) > 1
|
||||
|
||||
def update_missing_keys(
|
||||
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
|
||||
) -> List[str]:
|
||||
if self.pre_quantized:
|
||||
return [key for key in missing_keys if ("weight" not in key)]
|
||||
else:
|
||||
return missing_keys
|
||||
|
||||
# Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
|
||||
def update_expected_keys(
|
||||
self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
|
||||
) -> List[str]:
|
||||
if not self.pre_quantized:
|
||||
return expected_keys
|
||||
|
||||
# Collects all quantizable (linear) layers
|
||||
def _find_hqq_quantizable_layers(model, layers):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, (torch.nn.Linear)):
|
||||
layers.add(module.name)
|
||||
_find_hqq_quantizable_layers(module, layers)
|
||||
|
||||
new_keys = set(expected_keys)
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
# Name modules
|
||||
for name, module in model.named_modules():
|
||||
module.name = name
|
||||
|
||||
# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
|
||||
_valid_modules = set()
|
||||
_find_hqq_quantizable_layers(model, _valid_modules)
|
||||
_valid_modules -= set(model.config.quantization_config["skip_modules"])
|
||||
|
||||
# Append new expected layers based on _ref_keys
|
||||
_ref_keys = HQQLinear(
|
||||
linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
|
||||
).state_dict_keys() - {"bias"}
|
||||
|
||||
# Clean-up
|
||||
_rm_keys = set()
|
||||
for key in new_keys:
|
||||
if any(_module in key for _module in _valid_modules):
|
||||
_rm_keys.add(key)
|
||||
new_keys -= _rm_keys
|
||||
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
|
||||
|
||||
# Re-populate Linear/HQQLinear
|
||||
for _module in _valid_modules:
|
||||
if _module + ".weight" in loaded_keys:
|
||||
new_keys.add(_module + ".weight")
|
||||
else:
|
||||
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
|
||||
if _module + ".bias" in loaded_keys:
|
||||
new_keys.add(_module + ".bias")
|
||||
|
||||
return list(new_keys)
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
@@ -99,9 +158,18 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
|
||||
if self.pre_quantized:
|
||||
return (
|
||||
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
|
||||
and tensor_name != "weight"
|
||||
and tensor_name != "bias"
|
||||
)
|
||||
else:
|
||||
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
@@ -122,13 +190,43 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
layer_name = param_name.replace(".weight", "").replace(".bias", "")
|
||||
layer_name = ".".join(param_name.split(".")[:-1])
|
||||
parent_module = find_parent(model, layer_name)
|
||||
node = layer_name.split(".")[-1]
|
||||
|
||||
# Step 0: set module state_dict
|
||||
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
|
||||
# set module state_dict
|
||||
module_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if layer_name + "." in k:
|
||||
module_state_dict[k.split(".")[-1]] = v
|
||||
if unexpected_keys is not None and k in unexpected_keys:
|
||||
unexpected_keys.remove(k)
|
||||
|
||||
if self.pre_quantized:
|
||||
if isinstance(module, HQQLinear):
|
||||
return
|
||||
else:
|
||||
hqq_layer = HQQLinear(
|
||||
linear_layer=None,
|
||||
quant_config=None,
|
||||
compute_dtype=self.torch_dtype,
|
||||
device=target_device,
|
||||
)
|
||||
|
||||
hqq_layer.load_state_dict(module_state_dict)
|
||||
|
||||
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
|
||||
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
|
||||
|
||||
if self.using_multi_gpu:
|
||||
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
|
||||
|
||||
setattr(parent_module, node, hqq_layer)
|
||||
|
||||
# cleanup
|
||||
del module.__dict__, module
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
# Step 1: populate module with weight/bias from module state dict
|
||||
for key in module_state_dict:
|
||||
@@ -136,7 +234,6 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
|
||||
# directly doesn't work.
|
||||
|
||||
if hasattr(module, "quant_config"):
|
||||
hqq_layer = HQQLinear(
|
||||
module,
|
||||
@@ -192,7 +289,7 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
return model
|
||||
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
|
||||
@@ -92,6 +92,7 @@ ACCELERATE_MIN_VERSION = "0.26.0"
|
||||
FSDP_MIN_VERSION = "1.12.0"
|
||||
GGUF_MIN_VERSION = "0.10.0"
|
||||
XLA_FSDPV2_MIN_VERSION = "2.2.0"
|
||||
HQQ_MIN_VERSION = "0.2.1"
|
||||
|
||||
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
@@ -181,7 +182,7 @@ _torchao_available = _is_package_available("torchao")
|
||||
_torchdistx_available = _is_package_available("torchdistx")
|
||||
_torchvision_available = _is_package_available("torchvision")
|
||||
_mlx_available = _is_package_available("mlx")
|
||||
_hqq_available = _is_package_available("hqq")
|
||||
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
|
||||
_tiktoken_available = _is_package_available("tiktoken")
|
||||
_blobfile_available = _is_package_available("blobfile")
|
||||
_liger_kernel_available = _is_package_available("liger_kernel")
|
||||
@@ -323,8 +324,8 @@ def is_torch_deterministic():
|
||||
return True
|
||||
|
||||
|
||||
def is_hqq_available():
|
||||
return _hqq_available
|
||||
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
|
||||
return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
|
||||
|
||||
|
||||
def is_pygments_available():
|
||||
|
||||
@@ -193,15 +193,9 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
Number of bits. Supported values are (8, 4, 3, 2, 1).
|
||||
group_size (`int`, *optional*, defaults to 64):
|
||||
Group-size value. Supported values are any value that is divisble by weight.shape[axis]).
|
||||
quant_zero (`bool`, *optional*, defaults to `True`):
|
||||
Quantize the zero-point if set to `True`.
|
||||
quant_scale (`bool`, *optional*, defaults to `False`):
|
||||
Quantize the scaling if set to `True`.
|
||||
offload_meta (`bool`, *optional*, defaults to `False`):
|
||||
Offload the meta-data to the CPU if set to `True`.
|
||||
view_as_float (`bool`, *optional*, defaults to `False`):
|
||||
View the quantized weight as float (used in distributed training) if set to `True`.
|
||||
axis (`int`, *optional*, defaults to 0):
|
||||
axis (`Optional[int]`, *optional*):
|
||||
Axis along which grouping is performed. Supported values are 0 or 1.
|
||||
dynamic_config (dict, *optional*):
|
||||
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
|
||||
@@ -216,11 +210,8 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
self,
|
||||
nbits: int = 4,
|
||||
group_size: int = 64,
|
||||
quant_zero: bool = True,
|
||||
quant_scale: bool = False,
|
||||
offload_meta: bool = False,
|
||||
view_as_float: bool = False,
|
||||
axis: int = 0,
|
||||
axis: Optional[int] = None,
|
||||
dynamic_config: Optional[dict] = None,
|
||||
skip_modules: List[str] = ["lm_head"],
|
||||
**kwargs,
|
||||
@@ -228,6 +219,16 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
||||
|
||||
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
|
||||
if deprecated_key in kwargs:
|
||||
logger.info(
|
||||
deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
|
||||
)
|
||||
|
||||
if axis is None:
|
||||
axis = 1
|
||||
logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
|
||||
|
||||
if axis not in [0, 1]:
|
||||
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
|
||||
|
||||
@@ -240,9 +241,6 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
**{
|
||||
"nbits": nbits,
|
||||
"group_size": group_size,
|
||||
"quant_zero": quant_zero,
|
||||
"quant_scale": quant_scale,
|
||||
"offload_meta": offload_meta,
|
||||
"view_as_float": view_as_float,
|
||||
"axis": axis,
|
||||
}
|
||||
@@ -259,12 +257,26 @@ class HqqConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config: Dict[str, Any]):
|
||||
"""
|
||||
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
|
||||
"""
|
||||
instance = cls()
|
||||
instance.quant_config = config["quant_config"]
|
||||
instance.skip_modules = config["skip_modules"]
|
||||
return instance
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
return self.quant_config
|
||||
return {
|
||||
"quant_config": self.quant_config,
|
||||
"quant_method": self.quant_method,
|
||||
"skip_modules": self.skip_modules,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
config_dict = self.to_dict()
|
||||
|
||||
@@ -94,8 +94,7 @@ class HqqConfigTest(unittest.TestCase):
|
||||
quantization_config = HqqConfig()
|
||||
hqq_orig_config = quantization_config.to_dict()
|
||||
|
||||
for key in hqq_orig_config:
|
||||
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
|
||||
self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"])
|
||||
|
||||
|
||||
@slow
|
||||
@@ -109,32 +108,7 @@ class HQQTest(unittest.TestCase):
|
||||
"""
|
||||
Simple LLM model testing fp16
|
||||
"""
|
||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||
check_forward(self, hqq_runner.model)
|
||||
|
||||
def test_f16_quantized_model_with_offloading(self):
|
||||
"""
|
||||
Simple LLM model testing bfp16 with meta-data offloading
|
||||
"""
|
||||
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
|
||||
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
|
||||
quant_config = HqqConfig(
|
||||
dynamic_config={
|
||||
"self_attn.q_proj": q4_config,
|
||||
"self_attn.k_proj": q4_config,
|
||||
"self_attn.v_proj": q4_config,
|
||||
"self_attn.o_proj": q4_config,
|
||||
"mlp.gate_proj": q3_config,
|
||||
"mlp.up_proj": q3_config,
|
||||
"mlp.down_proj": q3_config,
|
||||
}
|
||||
)
|
||||
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
@@ -157,7 +131,7 @@ class HQQTestMultiGPU(unittest.TestCase):
|
||||
Simple LLM model testing fp16 with multi-gpu
|
||||
"""
|
||||
|
||||
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
|
||||
quant_config = HqqConfig(nbits=8, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
|
||||
@@ -165,3 +139,44 @@ class HQQTestMultiGPU(unittest.TestCase):
|
||||
|
||||
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
|
||||
check_forward(self, hqq_runner.model)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_accelerate
|
||||
class HQQSerializationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
cleanup()
|
||||
|
||||
def test_model_serialization(self):
|
||||
"""
|
||||
Simple HQQ LLM save/load test
|
||||
"""
|
||||
quant_config = HqqConfig(nbits=4, group_size=64)
|
||||
|
||||
hqq_runner = HQQLLMRunner(
|
||||
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
|
||||
)
|
||||
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_ref = hqq_runner.model.forward(input_tensor).logits
|
||||
|
||||
# Save
|
||||
saved_model_id = "quant_model"
|
||||
hqq_runner.model.save_pretrained(saved_model_id)
|
||||
|
||||
# Remove old model
|
||||
del hqq_runner.model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load and check if the logits match
|
||||
model_loaded = AutoModelForCausalLM.from_pretrained(
|
||||
"quant_model", torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
|
||||
|
||||
Reference in New Issue
Block a user