[bnb] Let's make serialization of int8 models possible (#22177)

* make serialization of int8 models possible

* make fixup

* add docs

* add ability to push to hub and save pretrained

* fixes

* more addition

* more tests

* fix issues

* change variable

* clearer message

* adapt from suggestions

* few fixes

* remove unused function

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address last comments

* last warning

* clarify doc

* protect import

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-04-12 14:01:18 +02:00
committed by GitHub
parent 523ca4e016
commit 370f0ca18c
6 changed files with 274 additions and 19 deletions

View File

@@ -19,6 +19,7 @@ import unittest
from packaging import version
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
@@ -150,6 +151,13 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_warns_save_pretrained(self):
    r"""
    Check that calling `save_pretrained` on a model converted to 8-bit emits a `UserWarning`.
    """
    # The warning capture is entered first and exited last, so it wraps the whole lifetime of the temp directory.
    with self.assertWarns(UserWarning):
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_8bit.save_pretrained(tmpdirname)
def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error
@@ -165,13 +173,6 @@ class MixedInt8Test(BaseMixedInt8Test):
llm_int8_enable_fp32_cpu_offload=True,
)
def test_warns_save_pretrained(self):
    r"""
    Check that calling `save_pretrained` on a model converted to 8-bit emits a `UserWarning`.
    """
    # The warning capture is entered first and exited last, so it wraps the whole lifetime of the temp directory.
    with self.assertWarns(UserWarning):
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_8bit.save_pretrained(tmpdirname)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
@@ -219,6 +220,77 @@ class MixedInt8Test(BaseMixedInt8Test):
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
def test_int8_serialization(self):
    r"""
    Test whether it is possible to serialize a model in 8-bit.

    The 8-bit model is saved to a temporary directory, reloaded with `load_in_8bit=True`, and then checked for
    an `Int8Params` weight carrying an `SCB` attribute and for the expected generation output.
    """
    from bitsandbytes.nn import Int8Params

    with tempfile.TemporaryDirectory() as tmpdirname:
        self.model_8bit.save_pretrained(tmpdirname)

        # check that the `quantization_config` entry is present in the saved config
        config = AutoConfig.from_pretrained(tmpdirname)
        self.assertTrue(hasattr(config, "quantization_config"))

        model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

        weight = model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight
        # `assertIs` reports the actual class on failure instead of a bare "False is not true"
        self.assertIs(weight.__class__, Int8Params)
        self.assertTrue(hasattr(weight, "SCB"))

        # generate
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        self.assertEqual(
            self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
        )
def test_int8_serialization_sharded(self):
    r"""
    Test whether it is possible to serialize a model in 8-bit - sharded version.

    The 8-bit model is saved with `max_shard_size="200MB"`, reloaded from the directory without passing
    `load_in_8bit`, and then checked for an `Int8Params` weight carrying an `SCB` attribute and for the
    expected generation output.
    """
    from bitsandbytes.nn import Int8Params

    with tempfile.TemporaryDirectory() as tmpdirname:
        self.model_8bit.save_pretrained(tmpdirname, max_shard_size="200MB")

        # check that the `quantization_config` entry is present in the saved config
        config = AutoConfig.from_pretrained(tmpdirname)
        self.assertTrue(hasattr(config, "quantization_config"))

        model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)

        weight = model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight
        # `assertIs` reports the actual class on failure instead of a bare "False is not true"
        self.assertIs(weight.__class__, Int8Params)
        self.assertTrue(hasattr(weight, "SCB"))

        # generate
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        self.assertEqual(
            self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
        )
def test_int8_from_pretrained(self):
    r"""
    Test whether loading an 8-bit model from the Hub works as expected.

    The checkpoint is loaded without passing `load_in_8bit`, then checked for an `Int8Params` weight carrying
    an `SCB` attribute and for the expected generation output.
    """
    from bitsandbytes.nn import Int8Params

    model_id = "ybelkada/bloom-1b7-8bit"

    model = AutoModelForCausalLM.from_pretrained(model_id)

    weight = model.transformer.h[0].mlp.dense_4h_to_h.weight
    # `assertIs` reports the actual class on failure instead of a bare "False is not true"
    self.assertIs(weight.__class__, Int8Params)
    self.assertTrue(hasattr(weight, "SCB"))

    # generate
    encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
    output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

    self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_bitsandbytes
@require_accelerate
@@ -289,6 +361,38 @@ class MixedInt8T5Test(unittest.TestCase):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)
def test_inference_with_keep_in_fp32_serialized(self):
    r"""
    Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly on
    a serialized model.
    `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
    both cases, so each model is loaded in 8-bit, saved to a temporary directory, reloaded from it and used for
    generation.
    """
    import bitsandbytes as bnb

    from transformers import T5ForConditionalGeneration

    # test with `t5-small`
    model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)

        model = T5ForConditionalGeneration.from_pretrained(tmp_dir)

        # there was a bug with decoders - this test checks that it is fixed
        self.assertIsInstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

    # test with `flan-t5-small`
    model = T5ForConditionalGeneration.from_pretrained(
        self.dense_act_model_name, load_in_8bit=True, device_map="auto"
    )

    # Round-trip through `save_pretrained` here too, so the serialized path is exercised for this variant
    # and not only the direct 8-bit load.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)

        model = T5ForConditionalGeneration.from_pretrained(tmp_dir)

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):