[core] fix 4bit num_parameters (#26132)

* fix 4bit `num_parameters`

* stronger check
This commit is contained in:
Younes Belkada
2023-09-13 14:12:35 +02:00
committed by GitHub
parent 7db1ad63d9
commit c8b26096d4
2 changed files with 35 additions and 3 deletions

View File

@@ -118,6 +118,17 @@ class Bnb4BitTest(Base4bitTest):
gc.collect()
torch.cuda.empty_cache()
def test_quantization_num_parameters(self):
r"""
Test if the number of returned parameters is correct
See: https://github.com/huggingface/transformers/issues/25978
"""
num_params_4bit = self.model_4bit.num_parameters()
num_params_fp16 = self.model_fp16.num_parameters()
self.assertEqual(num_params_4bit, num_params_fp16)
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized