[bnb] Fix bnb config json serialization (#24137)

* fix bnb config json serialization

* forward contrib credits from discussions

---------

Co-authored-by: Andrechang <Andrechang@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-06-09 13:41:14 +02:00
committed by GitHub
parent e2972dffdd
commit a6d05d55f6
3 changed files with 33 additions and 0 deletions

View File

@@ -784,6 +784,13 @@ class PretrainedConfig(PushToHubMixin):
): ):
serializable_config_dict[key] = value serializable_config_dict[key] = value
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_torch_dtype_to_str(serializable_config_dict) self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict return serializable_config_dict

View File

@@ -111,6 +111,19 @@ class Bnb4BitTest(Base4bitTest):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized
"""
config = self.model_4bit.config
self.assertTrue(hasattr(config, "quantization_config"))
_ = config.to_dict()
_ = config.to_diff_dict()
_ = config.to_json_string()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""
A simple test to check if the model conversion has been done correctly by checking on the A simple test to check if the model conversion has been done correctly by checking on the

View File

@@ -118,6 +118,19 @@ class MixedInt8Test(BaseMixedInt8Test):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized
"""
config = self.model_8bit.config
self.assertTrue(hasattr(config, "quantization_config"))
_ = config.to_dict()
_ = config.to_diff_dict()
_ = config.to_json_string()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""
A simple test to check if the model conversion has been done correctly by checking on the A simple test to check if the model conversion has been done correctly by checking on the