From fd59fc1a7f86ea8e2dca79d7063eb809650cefd7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 9 Jun 2023 13:41:14 +0200 Subject: [PATCH] [`bnb`] Fix bnb config json serialization (#24137) * fix bnb config json serialization * forward contrib credits from discussions --------- Co-authored-by: Andrechang --- src/transformers/configuration_utils.py | 7 +++++++ tests/bitsandbytes/test_4bit.py | 13 +++++++++++++ tests/bitsandbytes/test_mixed_int8.py | 13 +++++++++++++ 3 files changed, 33 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6e8104c0cd..1bcdef152a 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -784,6 +784,13 @@ class PretrainedConfig(PushToHubMixin): ): serializable_config_dict[key] = value + if hasattr(self, "quantization_config"): + serializable_config_dict["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + self.dict_torch_dtype_to_str(serializable_config_dict) return serializable_config_dict diff --git a/tests/bitsandbytes/test_4bit.py b/tests/bitsandbytes/test_4bit.py index 0b6445b2c1..182dfb9a17 100644 --- a/tests/bitsandbytes/test_4bit.py +++ b/tests/bitsandbytes/test_4bit.py @@ -111,6 +111,19 @@ class Bnb4BitTest(Base4bitTest): gc.collect() torch.cuda.empty_cache() + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_4bit.config + + self.assertTrue(hasattr(config, "quantization_config")) + + _ = config.to_dict() + _ = config.to_diff_dict() + + _ = config.to_json_string() + def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the diff --git a/tests/bitsandbytes/test_mixed_int8.py b/tests/bitsandbytes/test_mixed_int8.py index 09157a251e..7927045d78 100644 --- a/tests/bitsandbytes/test_mixed_int8.py +++ b/tests/bitsandbytes/test_mixed_int8.py @@ -118,6 +118,19 @@ class MixedInt8Test(BaseMixedInt8Test): gc.collect() torch.cuda.empty_cache() + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_8bit.config + + self.assertTrue(hasattr(config, "quantization_config")) + + _ = config.to_dict() + _ = config.to_diff_dict() + + _ = config.to_json_string() + def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the