Quantized KV Cache (#30483)
* clean-up
* Update src/transformers/cache_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/cache_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/cache_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fixup
* Update tests/quantization/quanto_integration/test_quanto.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* more suggestions
* mapping if torch available
* run tests & add 'support_quantized' flag
* fix jamba test
* revert, will be fixed by another PR
* codestyle
* HQQ and versatile cache classes
* final update
* typo
* make tests happy
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e05baad861
commit
d583f1317b
@@ -17,13 +17,22 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
|
||||
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_quanto,
|
||||
require_read_token,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
@@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
|
||||
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
|
||||
|
||||
|
||||
# NOTE(review): the "quantized" cache implementation is expected to use the quanto
# backend by default, so the test is skipped (rather than erroring) when quanto is
# not installed — confirm against the QuantizedCacheConfig defaults.
@require_quanto
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
    """Integration test for text generation with a quantized KV cache."""

    @slow
    @require_read_token
    def test_quantized_cache(self):
        """Generate with `cache_implementation="quantized"` and compare to reference completions.

        Loads the gated `meta-llama/Llama-2-7b-hf` checkpoint in float16, generates 40 new
        tokens per prompt with sampling disabled, and asserts that the decoded text matches
        `EXPECTED_TEXT_COMPLETION` exactly.
        """
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal",
        ]

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        # The prompts differ in length, so the batch is padded; padding is placed on the
        # left and "</s>" is used as the pad token.
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)

        # do_sample=False keeps the output deterministic so it can be compared verbatim.
        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized")
        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
Reference in New Issue
Block a user