[bnb] Let's make serialization of 4bit models possible (#26037)
* updated bitsandbytes.py * rm test_raise_* from test_4bit.py * add test_4bit_serialization.py * modeling_utils bulk edits * bnb_ver 0.41.3 in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * @slow reinstated Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * bnb ver 0.41.3 in src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * rm bnb version todo in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * moved 4b serialization tests to test_4bit * tests upd for opt * to torch_device Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * ruff fixes to tests * rm redundant bnb version check in mod_utils Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded modeling_utils.py::2188 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded test in modeling_utils.py::2199 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed NOT getattr(self, "is_8bit_serializable") Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * setting model.is_4bit_serializable * rm separate fp16_statistics arg from set_module... * rm else branch in integrations::bnb::set_module * bnb 4bit dtype check * upd comment on 4bit weights * upd tests for FP4 safe --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import unittest
|
||||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -29,6 +30,7 @@ from transformers import (
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
@@ -36,13 +38,21 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def get_some_linear_layer(model):
|
||||
if model.config.model_type == "gpt2":
|
||||
return model.transformer.h[0].mlp.c_fc
|
||||
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||
elif model.config.model_type == "opt":
|
||||
try:
|
||||
return model.decoder.layers[0].fc1
|
||||
except AttributeError:
|
||||
# for AutoModelforCausalLM
|
||||
return model.model.decoder.layers[0].fc1
|
||||
else:
|
||||
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -68,6 +78,10 @@ if is_torch_available():
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_raise_on_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_4bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_raise_if_config_and_load_in_4bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_4bit` raises an error
|
||||
"""
|
||||
bnb_config = BitsAndBytesConfig()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
quantization_config=bnb_config,
|
||||
load_in_4bit=True,
|
||||
device_map="auto",
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
@@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase):
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
# test with `t5-small`
|
||||
@@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest):
|
||||
class Bnb4BitGPT2Test(Bnb4BitTest):
|
||||
model_name = "gpt2-xl"
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class BaseSerializationTest(unittest.TestCase):
|
||||
model_name = "facebook/opt-125m"
|
||||
input_text = "Mars colonists' favorite meals are"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default.
|
||||
See ExtendedSerializationTest class for more params combinations.
|
||||
"""
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
self.quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type=quant_type,
|
||||
bnb_4bit_use_double_quant=double_quant,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
model_0 = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
quantization_config=self.quantization_config,
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
|
||||
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
|
||||
|
||||
# checking quantized linear module weight
|
||||
linear = get_some_linear_layer(model_1)
|
||||
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
|
||||
self.assertTrue(hasattr(linear.weight, "quant_state"))
|
||||
self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
|
||||
|
||||
# checking memory footpring
|
||||
self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
|
||||
|
||||
# Matching all parameters and their quant_state items:
|
||||
d0 = dict(model_0.named_parameters())
|
||||
d1 = dict(model_1.named_parameters())
|
||||
self.assertTrue(d0.keys() == d1.keys())
|
||||
|
||||
for k in d0.keys():
|
||||
self.assertTrue(d0[k].shape == d1[k].shape)
|
||||
self.assertTrue(d0[k].device.type == d1[k].device.type)
|
||||
self.assertTrue(d0[k].device == d1[k].device)
|
||||
self.assertTrue(d0[k].dtype == d1[k].dtype)
|
||||
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
|
||||
|
||||
if isinstance(d0[k], bnb.nn.modules.Params4bit):
|
||||
for v0, v1 in zip(
|
||||
d0[k].quant_state.as_dict().values(),
|
||||
d1[k].quant_state.as_dict().values(),
|
||||
):
|
||||
if isinstance(v0, torch.Tensor):
|
||||
self.assertTrue(torch.equal(v0, v1.to(v0.device)))
|
||||
else:
|
||||
self.assertTrue(v0 == v1)
|
||||
|
||||
# comparing forward() outputs
|
||||
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
out_0 = model_0(**encoded_input)
|
||||
out_1 = model_1(**encoded_input)
|
||||
self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
|
||||
|
||||
# comparing generate() outputs
|
||||
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||
output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
|
||||
output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)
|
||||
|
||||
def _decode(token):
|
||||
return tokenizer.decode(token, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(
|
||||
[_decode(x) for x in output_sequences_0],
|
||||
[_decode(x) for x in output_sequences_1],
|
||||
)
|
||||
|
||||
|
||||
class ExtendedSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
tests more combinations of parameters
|
||||
"""
|
||||
|
||||
def test_nf4_single_unsafe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
|
||||
|
||||
def test_nf4_single_safe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
|
||||
|
||||
def test_nf4_double_unsafe(self):
|
||||
self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
|
||||
|
||||
# nf4 double safetensors quantization is tested in test_serialization() method from the parent class
|
||||
|
||||
def test_fp4_single_unsafe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
|
||||
|
||||
def test_fp4_single_safe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
|
||||
|
||||
def test_fp4_double_unsafe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
|
||||
|
||||
def test_fp4_double_safe(self):
|
||||
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
|
||||
|
||||
|
||||
class BloomSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
default BaseSerializationTest config tested with Bloom family model
|
||||
"""
|
||||
|
||||
model_name = "bigscience/bloom-560m"
|
||||
|
||||
|
||||
class GPTSerializationTest(BaseSerializationTest):
|
||||
"""
|
||||
default BaseSerializationTest config tested with GPT family model
|
||||
"""
|
||||
|
||||
model_name = "gpt2-xl"
|
||||
|
||||
Reference in New Issue
Block a user