From 9b25c164bdb2754002e118065bc5045436b72773 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 2 Nov 2023 12:03:51 +0100
Subject: [PATCH] [`core` / `Quantization`] Fix for 8bit serialization tests (#27234)

* fix for 8bit serialization

* added regression tests.

* fixup
---
 src/transformers/modeling_utils.py        |  8 ++++++-
 tests/quantization/bnb/test_mixed_int8.py | 27 +++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e48c98c791..fcb51e6a56 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2110,7 +2110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # We're going to remove aliases before saving
             ptrs = collections.defaultdict(list)
             for name, tensor in state_dict.items():
-                ptrs[id_tensor_storage(tensor)].append(name)
+                # Sometimes in the state_dict we have non-tensor objects.
+                # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                if isinstance(tensor, torch.Tensor):
+                    ptrs[id_tensor_storage(tensor)].append(name)
+                else:
+                    # In the non-tensor case, fall back to the pointer of the object itself
+                    ptrs[id(tensor)].append(name)
 
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 3be1e5582a..da2ce55d31 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -369,6 +369,33 @@ class MixedInt8Test(BaseMixedInt8Test):
             self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
         )
 
+    def test_int8_serialization_regression(self):
+        r"""
+        Test whether it is possible to serialize a model in 8-bit without safetensors (`safe_serialization=False`)
+        """
+        from bitsandbytes.nn import Int8Params
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.model_8bit.save_pretrained(tmpdirname, safe_serialization=False)
+
+            # check that `quantization_config` was saved as part of the config
+            config = AutoConfig.from_pretrained(tmpdirname)
+            self.assertTrue(hasattr(config, "quantization_config"))
+
+            model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
+
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertIs(linear.weight.__class__, Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
+
+            # the reloaded model must still generate the expected text
+            encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+            output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+            self.assertEqual(
+                self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
+            )
+
     def test_int8_serialization_sharded(self):
         r"""
         Test whether it is possible to serialize a model in 8-bit - sharded version.