[RWKV] Final fix RWMV 4bit (#26134)
* Final fix RWMV 4bit * fixup * add a test * add more clarifications
This commit is contained in:
@@ -31,6 +31,7 @@ from ...utils import (
|
|||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_bitsandbytes_available,
|
||||||
is_ninja_available,
|
is_ninja_available,
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -735,18 +736,35 @@ class RwkvModel(RwkvPreTrainedModel):
|
|||||||
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
elif hasattr(block.attention.output.weight, "quant_state"):
|
elif hasattr(block.attention.output.weight, "quant_state"):
|
||||||
block.attention.output.weight.quant_state[0].div_(
|
self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
|
||||||
2 ** int(block_id // self.config.rescale_every)
|
self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
|
||||||
)
|
|
||||||
block.feed_forward.value.weight.quant_state[0].div_(
|
|
||||||
2 ** int(block_id // self.config.rescale_every)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
|
||||||
self.layers_are_rescaled = not self.training
|
self.layers_are_rescaled = not self.training
|
||||||
|
|
||||||
|
def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
|
||||||
|
r"""
|
||||||
|
Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
|
||||||
|
be quantized again.
|
||||||
|
"""
|
||||||
|
if not is_bitsandbytes_available():
|
||||||
|
raise ImportError("Please install bitsandbytes to use this method.")
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
|
dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
|
||||||
|
|
||||||
|
dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
|
||||||
|
|
||||||
|
# re-quantize the model:
|
||||||
|
# we need to put it first on CPU then back to the device
|
||||||
|
# this will create an overhead :/
|
||||||
|
# We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
|
||||||
|
# bugs with bnb
|
||||||
|
quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
|
||||||
|
setattr(target_layer, "weight", quant_weight)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
# 4-bit parameters are packed in uint8 variables
|
# 4-bit parameters are packed in uint8 variables
|
||||||
self.assertTrue(module.weight.dtype == torch.uint8)
|
self.assertTrue(module.weight.dtype == torch.uint8)
|
||||||
|
|
||||||
|
def test_rwkv_4bit(self):
|
||||||
|
r"""
|
||||||
|
A simple test to check if 4-bit RWKV inference works as expected.
|
||||||
|
"""
|
||||||
|
model_id = "RWKV/rwkv-4-169m-pile"
|
||||||
|
|
||||||
|
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
|
||||||
|
tok = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
text = "Hello my name is"
|
||||||
|
input_ids = tok.encode(text, return_tensors="pt").to(0)
|
||||||
|
|
||||||
|
_ = model.generate(input_ids, max_new_tokens=30)
|
||||||
|
|
||||||
def test_generate_quality(self):
|
def test_generate_quality(self):
|
||||||
r"""
|
r"""
|
||||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||||
|
|||||||
Reference in New Issue
Block a user