Quantized KV Cache (#30483)
* clean-up * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * more suggestions * mapping if torch available * run tests & add 'support_quantized' flag * fix jamba test * revert, will be fixed by another PR * codestyle * HQQ and versatile cache classes * final update * typo * make tests happy --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e05baad861
commit
d583f1317b
@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_quanto,
|
||||
require_torch,
|
||||
require_torch_multi_accelerator,
|
||||
slow,
|
||||
@@ -55,7 +56,7 @@ if is_torch_available():
|
||||
ImageGPTForCausalImageModeling,
|
||||
SpeechEncoderDecoderModel,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1654,6 +1655,39 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@require_quanto
|
||||
def test_generate_with_quant_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_quantized_cache:
|
||||
self.skipTest("This model does not support the quantized cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"cache_implementation": "quantized",
|
||||
# careful with group size, should be divisor of model's hidden size
|
||||
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
|
||||
|
||||
# passing past key values of different type should raise Error
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_valyes=DynamicCache(), **generation_kwargs
|
||||
)
|
||||
|
||||
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
|
||||
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
Reference in New Issue
Block a user