[core] fix 4bit num_parameters (#26132)
* fix 4bit `num_parameters` * stronger check
This commit is contained in:
@@ -989,12 +989,33 @@ class ModuleUtilsMixin:
|
|||||||
embedding_param_names = [
|
embedding_param_names = [
|
||||||
f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
|
f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
|
||||||
]
|
]
|
||||||
non_embedding_parameters = [
|
total_parameters = [
|
||||||
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
||||||
]
|
]
|
||||||
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
|
||||||
else:
|
else:
|
||||||
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
total_parameters = list(self.parameters())
|
||||||
|
|
||||||
|
total_numel = []
|
||||||
|
is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
|
||||||
|
if is_loaded_in_4bit:
|
||||||
|
if is_bitsandbytes_available():
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
|
||||||
|
" make sure to install bitsandbytes with `pip install bitsandbytes`."
|
||||||
|
)
|
||||||
|
|
||||||
|
for param in total_parameters:
|
||||||
|
if param.requires_grad or not only_trainable:
|
||||||
|
# For 4bit models, two 4-bit values are packed into each stored uint8 element, so `numel()` reports only
|
||||||
|
# half of the real parameter count; multiply by 2 to recover the true number of parameters
|
||||||
|
if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
|
||||||
|
total_numel.append(param.numel() * 2)
|
||||||
|
else:
|
||||||
|
total_numel.append(param.numel())
|
||||||
|
|
||||||
|
return sum(total_numel)
|
||||||
|
|
||||||
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_quantization_num_parameters(self):
|
||||||
|
r"""
|
||||||
|
Test that `num_parameters()` returns the same count for the 4-bit model as for the fp16 model
|
||||||
|
|
||||||
|
See: https://github.com/huggingface/transformers/issues/25978
|
||||||
|
"""
|
||||||
|
num_params_4bit = self.model_4bit.num_parameters()
|
||||||
|
num_params_fp16 = self.model_fp16.num_parameters()
|
||||||
|
|
||||||
|
self.assertEqual(num_params_4bit, num_params_fp16)
|
||||||
|
|
||||||
def test_quantization_config_json_serialization(self):
|
def test_quantization_config_json_serialization(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the quantization config is correctly serialized and deserialized
|
A simple test to check if the quantization config is correctly serialized and deserialized
|
||||||
|
|||||||
Reference in New Issue
Block a user