Quantized KV Cache (#30483)

* clean-up

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-05-23 17:25:20 +05:00
committed by GitHub
parent e05baad861
commit d583f1317b
19 changed files with 652 additions and 28 deletions

View File

@@ -17,13 +17,22 @@ import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
from transformers.testing_utils import (
require_accelerate,
require_quanto,
require_read_token,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
if is_torch_available():
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
with self.assertRaises(ValueError) as e:
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
    """Integration test for text generation with `cache_implementation="quantized"`."""

    @slow
    @require_read_token
    def test_quantized_cache(self):
        """Generate from two prompts with the quantized cache and compare to fixed completions.

        Sampling is disabled (`do_sample=False`) and 40 new tokens are requested,
        so the decoded text is compared for exact equality against the stored
        reference strings.
        """
        checkpoint = "meta-llama/Llama-2-7b-hf"

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        # Reference completions, one per prompt, in the same order as `prompts`.
        expected_completions = [
            "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal",
        ]

        # The prompts are batched with padding=True; the tokenizer is configured to pad on the left.
        tokenizer = LlamaTokenizer.from_pretrained(checkpoint, pad_token="</s>", padding_side="left")
        model = LlamaForCausalLM.from_pretrained(
            checkpoint,
            device_map="sequential",
            torch_dtype=torch.float16,
        )

        encoded = tokenizer(prompts, return_tensors="pt", padding=True)
        encoded = encoded.to(torch_device)

        generation_kwargs = {
            "max_new_tokens": 40,
            "do_sample": False,
            "cache_implementation": "quantized",
        }
        output_ids = model.generate(**encoded, **generation_kwargs)

        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertEqual(expected_completions, decoded)