[core / Quantization] Fix for 8bit serialization tests (#27234)
* fix for 8bit serialization
* added regression tests.
* fixup
This commit is contained in:
@@ -2110,7 +2110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# We're going to remove aliases before saving
|
# We're going to remove aliases before saving
|
||||||
ptrs = collections.defaultdict(list)
|
ptrs = collections.defaultdict(list)
|
||||||
for name, tensor in state_dict.items():
|
for name, tensor in state_dict.items():
|
||||||
ptrs[id_tensor_storage(tensor)].append(name)
|
# Sometimes in the state_dict we have non-tensor objects.
|
||||||
|
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
|
ptrs[id_tensor_storage(tensor)].append(name)
|
||||||
|
else:
|
||||||
|
# In the non-tensor case, fall back to the pointer of the object itself
|
||||||
|
ptrs[id(tensor)].append(name)
|
||||||
|
|
||||||
# These are all the pointers of shared tensors.
|
# These are all the pointers of shared tensors.
|
||||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||||
|
|||||||
@@ -369,6 +369,33 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_int8_serialization_regression(self):
    r"""
    Regression test: an 8-bit model saved with `safe_serialization=False` (i.e. without
    safetensors) can be reloaded in 8-bit and still generates the expected text.
    """
    from bitsandbytes.nn import Int8Params

    with tempfile.TemporaryDirectory() as tmpdirname:
        self.model_8bit.save_pretrained(tmpdirname, safe_serialization=False)

        # check that the reloaded config exposes a `quantization_config` attribute
        config = AutoConfig.from_pretrained(tmpdirname)
        self.assertTrue(hasattr(config, "quantization_config"))

        model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

        linear = get_some_linear_layer(model_from_saved)
        # `assertIs` performs the same exact-class check as `__class__ == Int8Params`,
        # but reports both classes on failure instead of "False is not true".
        self.assertIs(linear.weight.__class__, Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

        # generate
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        self.assertEqual(
            self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
        )
|
||||||
|
|
||||||
def test_int8_serialization_sharded(self):
|
def test_int8_serialization_sharded(self):
|
||||||
r"""
|
r"""
|
||||||
Test whether it is possible to serialize a model in 8-bit - sharded version.
|
Test whether it is possible to serialize a model in 8-bit - sharded version.
|
||||||
|
|||||||
Reference in New Issue
Block a user