[core / Quantization] Fix for 8bit serialization tests (#27234)
* fix for 8bit serialization * added regression tests. * fixup
This commit is contained in:
@@ -2110,7 +2110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# We're going to remove aliases before saving
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in state_dict.items():
|
||||
# Sometimes in the state_dict we have non-tensor objects.
|
||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
else:
|
||||
# In the non-tensor case, fall back to the pointer of the object itself
|
||||
ptrs[id(tensor)].append(name)
|
||||
|
||||
# These are all the pointers of shared tensors.
|
||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||
|
||||
@@ -369,6 +369,33 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_serialization_regression(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit - using not safetensors
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
|
||||
# check that the file `quantization_config` is present
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
|
||||
|
||||
linear = get_some_linear_layer(model_from_saved)
|
||||
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_serialization_sharded(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit - sharded version.
|
||||
|
||||
Reference in New Issue
Block a user