[core / Quantization] Fix for 8bit serialization tests (#27234)

* fix for 8bit serialization

* added regression tests.

* fixup
This commit is contained in:
Younes Belkada
2023-11-02 12:03:51 +01:00
committed by GitHub
parent c52e429b1c
commit 9b25c164bd
2 changed files with 34 additions and 1 deletions

View File

@@ -369,6 +369,33 @@ class MixedInt8Test(BaseMixedInt8Test):
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
)
def test_int8_serialization_regression(self):
    r"""
    Regression test for 8-bit serialization without safetensors: the 8-bit model is saved with
    `safe_serialization=False`, reloaded in 8-bit, and must still expose quantized weights and
    generate the expected text.
    """
    from bitsandbytes.nn import Int8Params

    with tempfile.TemporaryDirectory() as save_dir:
        self.model_8bit.save_pretrained(save_dir, safe_serialization=False)

        # the config reloaded from the checkpoint must carry a `quantization_config` attribute
        saved_config = AutoConfig.from_pretrained(save_dir)
        has_quantization_config = hasattr(saved_config, "quantization_config")
        self.assertTrue(has_quantization_config)

        reloaded_model = AutoModelForCausalLM.from_pretrained(save_dir, load_in_8bit=True, device_map="auto")

        # the weight of a linear layer must still be an `Int8Params` with an `SCB` attribute
        linear_layer = get_some_linear_layer(reloaded_model)
        weight_is_int8 = linear_layer.weight.__class__ == Int8Params
        self.assertTrue(weight_is_int8)
        weight_has_scb = hasattr(linear_layer.weight, "SCB")
        self.assertTrue(weight_has_scb)

        # generate from the reloaded model and compare against the reference output
        tokenized = self.tokenizer(self.input_text, return_tensors="pt")
        prompt_ids = tokenized["input_ids"].to(0)
        generated = reloaded_model.generate(input_ids=prompt_ids, max_new_tokens=10)
        decoded_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        self.assertEqual(decoded_text, self.EXPECTED_OUTPUT)
def test_int8_serialization_sharded(self):
r"""
Test whether it is possible to serialize a model in 8-bit - sharded version.