From 4f7806ef7e6d717485a514c29fe60f09871a58f8 Mon Sep 17 00:00:00 2001 From: Poedator <24738311+poedator@users.noreply.github.com> Date: Thu, 21 Dec 2023 13:54:44 +0300 Subject: [PATCH] [bnb] Let's make serialization of 4bit models possible (#26037) * updated bitsandbytes.py * rm test_raise_* from test_4bit.py * add test_4bit_serialization.py * modeling_utils bulk edits * bnb_ver 0.41.3 in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * @slow reinstated Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * bnb ver 0.41.3 in src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * rm bnb version todo in integrations/bitsandbytes.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * moved 4b serialization tests to test_4bit * tests upd for opt * to torch_device Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * ruff fixes to tests * rm redundant bnb version check in mod_utils Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded modeling_utils.py::2188 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * restore _hf_peft_config_loaded test in modeling_utils.py::2199 Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fixed NOT getattr(self, "is_8bit_serializable") Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * setting model.is_4bit_serializable * rm separate fp16_statistics arg from set_module... * rm else branch in integrations::bnb::set_module * bnb 4bit dtype check * upd comment on 4bit weights * upd tests for FP4 safe --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/transformers/integrations/bitsandbytes.py | 68 ++++--- src/transformers/modeling_utils.py | 96 +++++++--- tests/quantization/bnb/test_4bit.py | 177 +++++++++++++++--- 3 files changed, 264 insertions(+), 77 deletions(-) diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index 1a8220b1ed..43aeaf6708 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -21,7 +21,7 @@ if is_accelerate_available(): logger = logging.get_logger(__name__) -def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None): +def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None): """ A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The @@ -37,8 +37,8 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non The device on which to set the tensor. value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any other device). - fp16_statistics (`torch.HalfTensor`, *optional*): - The list of fp16 statistics to set on the module, used for serialization. + quantized_stats (`dict[str, Any]`, *optional*): + Dict with items for either 4-bit or 8-bit serialization """ # Recurse if needed if "." in tensor_name: @@ -58,8 +58,7 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") - is_4bit = False - is_8bit = False + prequantized_loading = quantized_stats is not None if is_buffer or not is_bitsandbytes_available(): is_8bit = False is_4bit = False @@ -74,32 +73,53 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non new_value = old_value.to(device) elif isinstance(value, torch.Tensor): new_value = value.to("cpu") - if value.dtype == torch.int8: - is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( - "0.37.2" - ) - if not is_8bit_serializable: - raise ValueError( - "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) else: new_value = torch.tensor(value, device="cpu") # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization. # Since weights are saved in the correct "orientation", we skip transposing when loading. - if issubclass(module.source_cls, Conv1D) and fp16_statistics is None: + if issubclass(module.source_cls, Conv1D) and not prequantized_loading: new_value = new_value.T kwargs = old_value.__dict__ - if is_8bit: - new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) - elif is_4bit: - new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) + if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)): + raise ValueError( + f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status." + ) + + if is_8bit: + is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) + if prequantized_loading: + setattr(new_value, "SCB", quantized_stats["SCB"].to(device)) + elif is_4bit: + if prequantized_loading: + is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.41.3" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable: + raise ValueError( + "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Params4bit.from_prequantized( + data=new_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=device, + **kwargs, + ) + else: + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) module._parameters[tensor_name] = new_value - if fp16_statistics is not None: - setattr(module.weight, "SCB", fp16_statistics.to(device)) else: if value is None: @@ -117,7 +137,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non def _replace_with_bnb_linear( - model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, ): """ Private method that wraps the recursion for module replacement. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3c7c7b48d0..afdd09597b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -675,6 +675,7 @@ def _load_state_dict_into_meta_model( is_quantized=False, is_safetensors=False, keep_in_fp32_modules=None, + unexpected_keys=None, # passing `unexpected` for cleanup from quantization items ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -776,17 +777,41 @@ def _load_state_dict_into_meta_model( elif not is_quantized: # For backward compatibility with older versions of `accelerate` set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) - else: - if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys(): - fp16_statistics = state_dict[param_name.replace("weight", "SCB")] - else: - fp16_statistics = None + elif param.dtype in (torch.int8, torch.uint8) and is_quantized: + # handling newly quantized weights and loaded quantized weights + # edit the param.dtype restrictions and is_quantized condition when adding new quant methods + quantized_stats = {} + + if (param_name + ".quant_state.bitsandbytes__fp4" in state_dict) or ( + param_name + ".quant_state.bitsandbytes__nf4" in state_dict + ): + # 4bit loading. Collecting components for restoring quantized weight + # This can be expanded to make a universal call for any quantized weight loading + for k, v in state_dict.items(): + if param_name + "." in k: + quantized_stats[k] = v + unexpected_keys.remove(k) - if "SCB" not in param_name: set_module_quantized_tensor_to_device( - model, param_name, param_device, value=param, fp16_statistics=fp16_statistics + model, param_name, param_device, value=param, quantized_stats=quantized_stats ) + elif param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys(): + # 8bit loading. Could be combined with the above 4bit call. + # condition looks unreliable + fp16_statistics_key = param_name.replace("weight", "SCB") + unexpected_keys.remove(fp16_statistics_key) + set_module_quantized_tensor_to_device( + model, + param_name, + param_device, + value=param, + quantized_stats={"SCB": state_dict[fp16_statistics_key]}, + ) + else: + # loading not quantized params in quantized model + set_module_quantized_tensor_to_device(model, param_name, param_device, value=param) + return error_msgs, offload_index, state_dict_index @@ -2197,15 +2222,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix and not getattr(self, "is_8bit_serializable", False) and not _hf_peft_config_loaded ): - raise ValueError( - "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected" - " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed." + raise NotImplementedError( + "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed." ) - # If the model has adapters attached, you can save the adapters - if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded: + if ( + getattr(self, "is_loaded_in_4bit", False) + and not getattr(self, "is_4bit_serializable", False) + and not _hf_peft_config_loaded + ): raise NotImplementedError( - "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported" + "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed." ) if getattr(self, "_awq_is_fused", False): @@ -2774,8 +2803,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix use_safetensors = False if is_bitsandbytes_available(): + is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2") else: + is_4bit_serializable = False is_8bit_serializable = False if trust_remote_code is True: @@ -3064,10 +3095,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if low_cpu_mem_usage is None: low_cpu_mem_usage = True - if ( - is_8bit_serializable - and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES - and load_in_8bit + if quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and ( + (is_8bit_serializable and load_in_8bit) or (is_4bit_serializable and load_in_4bit) ): if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES: logger.warning( @@ -3077,8 +3106,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) config.quantization_config = quantization_config elif ( - is_8bit_serializable - and not load_in_8bit + (is_8bit_serializable or is_4bit_serializable) + and not (load_in_8bit or load_in_4bit) and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES ): quantization_config = config.quantization_config @@ -3093,8 +3122,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) load_in_8bit = quantization_config.load_in_8bit + load_in_4bit = quantization_config.load_in_4bit - if load_in_8bit: + if load_in_8bit or load_in_4bit: if torch_dtype is None: torch_dtype = torch.float16 if device_map is None: @@ -3112,12 +3142,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif ( not is_8bit_serializable - and not load_in_8bit + and not (load_in_8bit or load_in_4bit) and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES ): logger.warning( "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct" - " `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with " + " `bitsandbytes` version to support 4 and 8 bit serialization. Please install the latest version of `bitsandbytes` with " " `pip install --upgrade bitsandbytes`." ) @@ -3525,6 +3555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix config.quantization_config = quantization_config model.is_8bit_serializable = is_8bit_serializable + model.is_4bit_serializable = is_4bit_serializable if load_in_8bit and torch_dtype is None: logger.warning( @@ -4018,10 +4049,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] + if ( + state_dict[checkpoint_key].shape[-1] == 1 + and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() + ): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + pass + else: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] return mismatched_keys if resolved_archive_file is not None: @@ -4130,6 +4169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_quantized=is_quantized, is_safetensors=is_safetensors, keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, ) error_msgs += new_error_msgs else: @@ -4167,10 +4207,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - if is_quantized: - unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem] - missing_keys = [elem for elem in missing_keys if "SCB" not in elem] - if len(unexpected_keys) > 0: archs = [] if model.config.architectures is None else model.config.architectures warner = logger.warning if model.__class__.__name__ in archs else logger.info diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index a70e3d8832..5e034e49f9 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -20,6 +20,7 @@ import unittest from packaging import version from transformers import ( + AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, @@ -29,6 +30,7 @@ from transformers import ( pipeline, ) from transformers.testing_utils import ( + is_bitsandbytes_available, is_torch_available, require_accelerate, require_bitsandbytes, @@ -36,13 +38,21 @@ from transformers.testing_utils import ( require_torch_gpu, require_torch_multi_gpu, slow, + torch_device, ) def get_some_linear_layer(model): if model.config.model_type == "gpt2": return model.transformer.h[0].mlp.c_fc - return model.transformer.h[0].mlp.dense_4h_to_h + elif model.config.model_type == "opt": + try: + return model.decoder.layers[0].fc1 + except AttributeError: + # for AutoModelforCausalLM + return model.model.decoder.layers[0].fc1 + else: + return model.transformer.h[0].mlp.dense_4h_to_h if is_torch_available(): @@ -68,6 +78,10 @@ if is_torch_available(): return self.module(input, *args, **kwargs) + self.adapter(input) +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + @require_bitsandbytes @require_accelerate @require_torch @@ -225,28 +239,6 @@ class Bnb4BitTest(Base4bitTest): self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) - def test_raise_on_save_pretrained(self): - r""" - Test whether trying to save a model after converting it in 8-bit will throw a warning. - """ - with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname: - self.model_4bit.save_pretrained(tmpdirname) - - def test_raise_if_config_and_load_in_4bit(self): - r""" - Test that loading the model with the config and `load_in_4bit` raises an error - """ - bnb_config = BitsAndBytesConfig() - - with self.assertRaises(ValueError): - _ = AutoModelForCausalLM.from_pretrained( - self.model_name, - quantization_config=bnb_config, - load_in_4bit=True, - device_map="auto", - bnb_4bit_quant_type="nf4", - ) - def test_device_and_dtype_assignment(self): r""" Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. @@ -346,8 +338,6 @@ class Bnb4BitT5Test(unittest.TestCase): `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test both cases. """ - import bitsandbytes as bnb - from transformers import T5ForConditionalGeneration # test with `t5-small` @@ -521,3 +511,140 @@ class Bnb4BitTestTraining(Base4bitTest): class Bnb4BitGPT2Test(Bnb4BitTest): model_name = "gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187 + + +@require_bitsandbytes +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class BaseSerializationTest(unittest.TestCase): + model_name = "facebook/opt-125m" + input_text = "Mars colonists' favorite meals are" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): + r""" + Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default. + See ExtendedSerializationTest class for more params combinations. + """ + + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + self.quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=quant_type, + bnb_4bit_use_double_quant=double_quant, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model_0 = AutoModelForCausalLM.from_pretrained( + self.model_name, + quantization_config=self.quantization_config, + device_map=torch_device, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) + + config = AutoConfig.from_pretrained(tmpdirname) + self.assertTrue(hasattr(config, "quantization_config")) + + model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + # checking memory footpring + self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + for k in d0.keys(): + self.assertTrue(d0[k].shape == d1[k].shape) + self.assertTrue(d0[k].device.type == d1[k].device.type) + self.assertTrue(d0[k].device == d1[k].device) + self.assertTrue(d0[k].dtype == d1[k].dtype) + self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device))) + + if isinstance(d0[k], bnb.nn.modules.Params4bit): + for v0, v1 in zip( + d0[k].quant_state.as_dict().values(), + d1[k].quant_state.as_dict().values(), + ): + if isinstance(v0, torch.Tensor): + self.assertTrue(torch.equal(v0, v1.to(v0.device))) + else: + self.assertTrue(v0 == v1) + + # comparing forward() outputs + encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + out_0 = model_0(**encoded_input) + out_1 = model_1(**encoded_input) + self.assertTrue(torch.equal(out_0["logits"], out_1["logits"])) + + # comparing generate() outputs + encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10) + output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10) + + def _decode(token): + return tokenizer.decode(token, skip_special_tokens=True) + + self.assertEqual( + [_decode(x) for x in output_sequences_0], + [_decode(x) for x in output_sequences_1], + ) + + +class ExtendedSerializationTest(BaseSerializationTest): + """ + tests more combinations of parameters + """ + + def test_nf4_single_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False) + + def test_nf4_single_safe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True) + + def test_nf4_double_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False) + + # nf4 double safetensors quantization is tested in test_serialization() method from the parent class + + def test_fp4_single_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False) + + def test_fp4_single_safe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True) + + def test_fp4_double_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False) + + def test_fp4_double_safe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) + + +class BloomSerializationTest(BaseSerializationTest): + """ + default BaseSerializationTest config tested with Bloom family model + """ + + model_name = "bigscience/bloom-560m" + + +class GPTSerializationTest(BaseSerializationTest): + """ + default BaseSerializationTest config tested with GPT family model + """ + + model_name = "gpt2-xl"