[bnb] Fix bnb config json serialization (#24137)
* fix bnb config json serialization * forward contrib credits from discussions --------- Co-authored-by: Andrechang <Andrechang@users.noreply.github.com>
This commit is contained in:
committed by
Sylvain Gugger
parent
a272e4135c
commit
fd59fc1a7f
@@ -784,6 +784,13 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
):
|
):
|
||||||
serializable_config_dict[key] = value
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
|
if hasattr(self, "quantization_config"):
|
||||||
|
serializable_config_dict["quantization_config"] = (
|
||||||
|
self.quantization_config.to_dict()
|
||||||
|
if not isinstance(self.quantization_config, dict)
|
||||||
|
else self.quantization_config
|
||||||
|
)
|
||||||
|
|
||||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||||
|
|
||||||
return serializable_config_dict
|
return serializable_config_dict
|
||||||
|
|||||||
@@ -111,6 +111,19 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_quantization_config_json_serialization(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if the quantization config is correctly serialized and deserialized
|
||||||
|
"""
|
||||||
|
config = self.model_4bit.config
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(config, "quantization_config"))
|
||||||
|
|
||||||
|
_ = config.to_dict()
|
||||||
|
_ = config.to_diff_dict()
|
||||||
|
|
||||||
|
_ = config.to_json_string()
|
||||||
|
|
||||||
def test_memory_footprint(self):
|
def test_memory_footprint(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the model conversion has been done correctly by checking on the
|
A simple test to check if the model conversion has been done correctly by checking on the
|
||||||
|
|||||||
@@ -118,6 +118,19 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_quantization_config_json_serialization(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if the quantization config is correctly serialized and deserialized
|
||||||
|
"""
|
||||||
|
config = self.model_8bit.config
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(config, "quantization_config"))
|
||||||
|
|
||||||
|
_ = config.to_dict()
|
||||||
|
_ = config.to_diff_dict()
|
||||||
|
|
||||||
|
_ = config.to_json_string()
|
||||||
|
|
||||||
def test_memory_footprint(self):
|
def test_memory_footprint(self):
|
||||||
r"""
|
r"""
|
||||||
A simple test to check if the model conversion has been done correctly by checking on the
|
A simple test to check if the model conversion has been done correctly by checking on the
|
||||||
|
|||||||
Reference in New Issue
Block a user