[core] fix 4bit num_parameters (#26132)
* fix 4bit `num_parameters` * stronger check
This commit is contained in:
@@ -989,12 +989,33 @@ class ModuleUtilsMixin:
|
||||
embedding_param_names = [
|
||||
f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
|
||||
]
|
||||
non_embedding_parameters = [
|
||||
total_parameters = [
|
||||
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
||||
]
|
||||
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
||||
else:
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
||||
total_parameters = list(self.parameters())
|
||||
|
||||
total_numel = []
|
||||
is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
|
||||
if is_loaded_in_4bit:
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
else:
|
||||
raise ValueError(
|
||||
"bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
|
||||
" make sure to install bitsandbytes with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
for param in total_parameters:
|
||||
if param.requires_grad or not only_trainable:
|
||||
# For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
|
||||
# used for the 4bit quantization (uint8 tensors are stored)
|
||||
if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
|
||||
total_numel.append(param.numel() * 2)
|
||||
else:
|
||||
total_numel.append(param.numel())
|
||||
|
||||
return sum(total_numel)
|
||||
|
||||
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||
"""
|
||||
|
||||
@@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
Test if the number of returned parameters is correct
|
||||
|
||||
See: https://github.com/huggingface/transformers/issues/25978
|
||||
"""
|
||||
num_params_4bit = self.model_4bit.num_parameters()
|
||||
num_params_fp16 = self.model_fp16.num_parameters()
|
||||
|
||||
self.assertEqual(num_params_4bit, num_params_fp16)
|
||||
|
||||
def test_quantization_config_json_serialization(self):
|
||||
r"""
|
||||
A simple test to check if the quantization config is correctly serialized and deserialized
|
||||
|
||||
Reference in New Issue
Block a user