From 7ccac73f749ce535851b9188f3867d5ed87c318c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 13 Sep 2023 16:30:20 +0200 Subject: [PATCH] [`RWKV`] Final fix RWMV 4bit (#26134) * Final fix RWMV 4bit * fixup * add a test * add more clarifications --- src/transformers/models/rwkv/modeling_rwkv.py | 30 +++++++++++++++---- tests/quantization/bnb/test_4bit.py | 16 ++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 2d20590628..db41bd3c95 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -31,6 +31,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_bitsandbytes_available, is_ninja_available, is_torch_cuda_available, logging, @@ -735,18 +736,35 @@ class RwkvModel(RwkvPreTrainedModel): block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) elif hasattr(block.attention.output.weight, "quant_state"): - block.attention.output.weight.quant_state[0].div_( - 2 ** int(block_id // self.config.rescale_every) - ) - block.feed_forward.value.weight.quant_state[0].div_( - 2 ** int(block_id // self.config.rescale_every) - ) + self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id) + self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id) else: block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) self.layers_are_rescaled = not self.training + def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): + r""" + Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will + be quantized again. + """ + if not is_bitsandbytes_available(): + raise ImportError("Please install bitsandbytes to use this method.") + import bitsandbytes as bnb + + dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state) + + dequant_weights.div_(2 ** int(block_id // self.config.rescale_every)) + + # re-quantize the model: + # we need to put it first on CPU then back to the device + # this will create an overhead :/ + # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid + # bugs with bnb + quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device) + setattr(target_layer, "weight", quant_weight) + @add_start_docstrings( """ diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index ce1dd336e9..801173da79 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) + def test_rwkv_4bit(self): + r""" + A simple test to check if 4-bit RWKV inference works as expected. + """ + model_id = "RWKV/rwkv-4-169m-pile" + + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True) + + model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + tok = AutoTokenizer.from_pretrained(model_id) + + text = "Hello my name is" + input_ids = tok.encode(text, return_tensors="pt").to(0) + + _ = model.generate(input_ids, max_new_tokens=30) + def test_generate_quality(self): r""" Test the generation quality of the quantized model and see that we are matching the expected output.