[RWKV] Final fix RWMV 4bit (#26134)
* Final fix RWMV 4bit * fixup * add a test * add more clarifications
This commit is contained in:
@@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest):
|
||||
# 4-bit parameters are packed in uint8 variables
|
||||
self.assertTrue(module.weight.dtype == torch.uint8)
|
||||
|
||||
def test_rwkv_4bit(self):
|
||||
r"""
|
||||
A simple test to check if 4-bit RWKV inference works as expected.
|
||||
"""
|
||||
model_id = "RWKV/rwkv-4-169m-pile"
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
text = "Hello my name is"
|
||||
input_ids = tok.encode(text, return_tensors="pt").to(0)
|
||||
|
||||
_ = model.generate(input_ids, max_new_tokens=30)
|
||||
|
||||
def test_generate_quality(self):
|
||||
r"""
|
||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||
|
||||
Reference in New Issue
Block a user