Quantized KV Cache (#30483)

* clean-up

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-05-23 17:25:20 +05:00
committed by GitHub
parent e05baad861
commit d583f1317b
19 changed files with 652 additions and 28 deletions

View File

@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_quanto,
require_torch,
require_torch_multi_accelerator,
slow,
@@ -55,7 +56,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1654,6 +1655,39 @@ class GenerationTesterMixin:
)
)
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
self.skipTest("This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"cache_implementation": "quantized",
# careful with group size, should be divisor of model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
}
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
# passing past key values of different type should raise Error
with self.assertRaises(ValueError):
model.generate(
input_ids, attention_mask=attention_mask, past_key_valyes=DynamicCache(), **generation_kwargs
)
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences

View File

@@ -17,13 +17,22 @@ import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
from transformers.testing_utils import (
require_accelerate,
require_quanto,
require_read_token,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
if is_torch_available():
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
with self.assertRaises(ValueError) as e:
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
@slow
@require_read_token
def test_quantized_cache(self):
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized")
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)