From c8b26096d4b092b96e65ee88367624bf7b837c36 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:12:35 +0200 Subject: [PATCH] [`core`] fix 4bit `num_parameters` (#26132) * fix 4bit `num_parameters` * stronger check --- src/transformers/modeling_utils.py | 27 ++++++++++++++++++++++++--- tests/quantization/bnb/test_4bit.py | 11 +++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 68df6ccdb8..9279950d3a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -989,12 +989,33 @@ class ModuleUtilsMixin: embedding_param_names = [ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) ] - non_embedding_parameters = [ + total_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + total_parameters = list(self.parameters()) + + total_numel = [] + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`." 
+ ) + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models the weights are stored packed, two 4-bit values per uint8 element, so `numel()` + # reports half of the real parameter count; multiply by 2 to recover it + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + total_numel.append(param.numel() * 2) + else: + total_numel.append(param.numel()) + + return sum(total_numel) def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int: """ diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 00d4109ca6..ce1dd336e9 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest): gc.collect() torch.cuda.empty_cache() + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + + See: https://github.com/huggingface/transformers/issues/25978 + """ + num_params_4bit = self.model_4bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_4bit, num_params_fp16) + def test_quantization_config_json_serialization(self): r""" A simple test to check if the quantization config is correctly serialized and deserialized