[RWKV] Final fix RWMV 4bit (#26134)

* Final fix RWMV 4bit

* fixup

* add a test

* add more clarifications
This commit is contained in:
Younes Belkada
2023-09-13 16:30:20 +02:00
committed by GitHub
parent 32ec7345f2
commit 7ccac73f74
2 changed files with 40 additions and 6 deletions

View File

@@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest):
# 4-bit parameters are packed in uint8 variables
self.assertTrue(module.weight.dtype == torch.uint8)
def test_rwkv_4bit(self):
r"""
A simple test to check if 4-bit RWKV inference works as expected.
"""
model_id = "RWKV/rwkv-4-169m-pile"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)
text = "Hello my name is"
input_ids = tok.encode(text, return_tensors="pt").to(0)
_ = model.generate(input_ids, max_new_tokens=30)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.