[bnb] Let's make serialization of int8 models possible (#22177)
* make serialization of int8 models possible
* make fixup
* add docs
* add ability to push to hub and save pretrained
* fixes
* more addition
* more tests
* fix issues
* change variable
* clearer message
* adapt from suggestions
* few fixes
* remove unused function
* Update src/transformers/utils/quantization_config.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address last comments
* last warning
* clarify doc
* protect import
* Update src/transformers/modeling_utils.py
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import unittest
|
||||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -150,6 +151,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_raise_if_config_and_load_in_8bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||
@@ -165,13 +173,6 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
@@ -219,6 +220,77 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
|
||||
def test_int8_serialization(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit.
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
# check that the file `quantization_config` is present
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
|
||||
|
||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_serialization_sharded(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit - sharded version.
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname, max_shard_size="200MB")
|
||||
|
||||
# check that the file `quantization_config` is present
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_from_pretrained(self):
|
||||
r"""
|
||||
Test whether loading a 8bit model from the Hub works as expected
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
model_id = "ybelkada/bloom-1b7-8bit"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@@ -289,6 +361,38 @@ class MixedInt8T5Test(unittest.TestCase):
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
def test_inference_with_keep_in_fp32_serialized(self):
|
||||
r"""
|
||||
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly on
|
||||
a serialized model.
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
# test with `t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
|
||||
|
||||
# there was a bug with decoders - this test checks that it is fixed
|
||||
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
# test with `flan-t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
|
||||
)
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user