diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index f6fa587fb6..930fdfb799 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf # Some slow tests require bnb RUN python3 -m pip install --no-cache-dir bitsandbytes +# Some tests require quanto +RUN python3 -m pip install --no-cache-dir quanto + # For `dinat` model # The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent) RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 6c7c70cb14..b000cc0677 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te ``` +## KV Cache Quantization + +The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value +cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models. +Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed. + +KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) +and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper. + +To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`. +Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`QuantizedCacheConfig`] class. 
+One has to indicate which quantization backend to use in the [`QuantizedCacheConfig`], the default is `quanto`. + + + +Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization. + + + + +```python +>>> import torch +>>> from transformers import AutoTokenizer, AutoModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") +>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0") +>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device) + +>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"}) +>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) +I like rock music because it's loud and energetic. It's a great way to express myself and rel + +>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20) +>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) +I like rock music because it's loud and energetic. I like to listen to it when I'm feeling +``` + ## Watermarking The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green". 
diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 04a4428a00..5bf8b5c4a0 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -360,6 +360,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] Cache - update +[[autodoc]] CacheConfig + - update + +[[autodoc]] QuantizedCacheConfig + - validate + [[autodoc]] DynamicCache - update - get_seq_length @@ -367,6 +373,14 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] QuantizedCache + - update + - get_seq_length + +[[autodoc]] QuantoQuantizedCache + +[[autodoc]] HQQQuantizedCache + [[autodoc]] SinkCache - update - get_seq_length @@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - get_seq_length - - reorder_cache + - reset ## Watermark Utils diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4255e30379..8da7a8b3e3 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1182,7 +1182,17 @@ else: _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] - _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"] + _import_structure["cache_utils"] = [ + "Cache", + "CacheConfig", + "DynamicCache", + "HQQQuantizedCache", + "QuantizedCache", + "QuantizedCacheConfig", + "QuantoQuantizedCache", + "SinkCache", + "StaticCache", + ] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", @@ -5792,7 +5802,17 @@ if TYPE_CHECKING: # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments - from .cache_utils import Cache, 
DynamicCache, SinkCache, StaticCache + from .cache_utils import ( + Cache, + CacheConfig, + DynamicCache, + HQQQuantizedCache, + QuantizedCache, + QuantizedCacheConfig, + QuantoQuantizedCache, + SinkCache, + StaticCache, + ) from .data.datasets import ( GlueDataset, GlueDataTrainingArguments, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 990a863e18..ad91edfcbb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,12 +1,21 @@ +import copy +import json +import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch from .configuration_utils import PretrainedConfig -from .utils import logging +from .utils import is_hqq_available, is_quanto_available, logging +if is_quanto_available(): + from quanto import QBitsTensor, qint2, qint4 + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + logger = logging.get_logger(__name__) @@ -82,6 +91,201 @@ class Cache: return None +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. 
+ """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. 
+ """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 4. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. 
+ Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be the same as the model's device. + """ + + def __init__( + self, + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -186,6 +390,168 @@ class DynamicCache(Cache): return cache +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. 
The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value states in original precision as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + self._quantized_key_cache: List[torch.Tensor] = [] + self._quantized_value_cache: List[torch.Tensor] = [] + + self.nbits = cache_config.nbits + self.residual_length = cache_config.residual_length + self.q_group_size = cache_config.q_group_size + self.axis_key = cache_config.axis_key + self.axis_value = cache_config.axis_value + self.compute_dtype = cache_config.compute_dtype + self.device = cache_config.device + + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) <= layer_idx: + self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) + self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) + self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + 
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.axis_value + ) + self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def _quantize(self, tensor, axis): + """Quantizes a key/value using a defined quantization method.""" + raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") + + def _dequantize(self, q_tensor): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") + + +class QuantoQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`,): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
+ """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + + def _quantize(self, tensor, axis): + qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`,): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
+ """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + class SinkCache(Cache): """ A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). 
It allows the model to diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index eb14c60d9a..0d1eba0bd5 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -31,6 +31,7 @@ from ..utils import ( download_url, extract_commit_hash, is_remote_url, + is_torch_available, logging, ) @@ -41,6 +42,12 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") +NEEDS_CACHE_CONFIG = {} + +if is_torch_available(): + from ..cache_utils import QuantizedCacheConfig + + NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig class GenerationMode(ExplicitEnum): @@ -299,6 +306,10 @@ class GenerationConfig(PushToHubMixin): cache_implementation (`str`, *optional*, default to `None`): Cache class that should be used when generating. + cache_config (`Union[CacheConfig, dict]`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and + it will be converted to its respective `CacheConfig` internally. + Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. 
> Wild card @@ -382,6 +393,13 @@ class GenerationConfig(PushToHubMixin): # Cache implementation self.cache_implementation = kwargs.pop("cache_implementation", None) + self.cache_config = kwargs.pop("cache_config", None) + if self.cache_implementation is not None: + cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] + if self.cache_config is None: + self.cache_config = cache_config_class() + elif isinstance(self.cache_config, dict): + self.cache_config = cache_config_class.from_dict(self.cache_config) # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) @@ -638,13 +656,26 @@ class GenerationConfig(PushToHubMixin): f"({self.num_beams})." ) - # check watermarking arguments + # 5. check `cache_config` + if self.cache_config is not None: + cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) + if cache_class is None: + raise ValueError( + "You provided a `cache_config` but the cache implementation you are using " + f"({self.cache_implementation}) does not require any config. Make sure to use the " + "correct cache implementation matching your cache config." + ) + if not isinstance(self.cache_config, cache_class): + self.cache_config = cache_class.from_dict(self.cache_config) + self.cache_config.validate() + + # 6. check watermarking arguments if self.watermarking_config is not None: if not isinstance(self.watermarking_config, WatermarkingConfig): self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config) self.watermarking_config.validate() - # 5. check common issue: passing `generate` arguments inside the generation config + # 7. 
check common issue: passing `generate` arguments inside the generation config generate_arguments = ( "logits_processor", "stopping_criteria", diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 149ce144e6..84c9dd995e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,15 @@ import torch import torch.distributed as dist from torch import nn -from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ..cache_utils import ( + Cache, + DynamicCache, + HQQQuantizedCache, + QuantizedCacheConfig, + QuantoQuantizedCache, + SlidingWindowCache, + StaticCache, +) from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -34,7 +42,14 @@ from ..models.auto import ( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging +from ..utils import ( + ModelOutput, + is_accelerate_available, + is_hqq_available, + is_quanto_available, + is_torchdynamo_compiling, + logging, +) from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -97,6 +112,7 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} +QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} @dataclass @@ -1658,20 +1674,43 @@ class GenerationMixin: "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " "Cache object) is unsupported. Please use only one of the two." 
) - elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not self._supports_cache_class: - raise ValueError( - "This model does not support the `cache_implementation` argument. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981." + elif generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs["past_key_values"] = self._get_cache( + generation_config.cache_implementation, batch_size, generation_config.max_length ) - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() ) - model_kwargs["past_key_values"] = self._get_cache( - generation_config.cache_implementation, batch_size, generation_config.max_length - ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. 
" + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs["past_key_values"] = cache_class(cache_config) + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 106f79ae8e..354962bab0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1284,6 +1284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _supports_cache_class = False _supports_static_cache = False + # Has support for a `QuantoQuantizedCache` instance as `past_key_values` + _supports_quantized_cache = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41c4e151a3..7d1b0e19fc 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -712,6 +712,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 850e3a3f81..67f4b819e9 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -937,6 +937,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def 
_init_weights(self, module: nn.Module): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 565a976fd7..474dccf308 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -698,6 +698,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1f4f6ac9a0..226d14c18b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -767,6 +767,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9b4b08239b..1630297cd8 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -745,6 +745,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index c48132c83c..ab9f8c3d85 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -539,6 +539,8 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 
= False _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + _supports_quantized_cache = True def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index df4e922c5a..160d70fe61 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -832,6 +832,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_cache_class = True _supports_sdpa = True + _supports_quantized_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index a42d7b7dec..f30cfe1947 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -784,11 +784,11 @@ class WhisperGenerationMixin: del generate_kwargs[key] seek_outputs = super().generate( segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, decoder_input_ids=decoder_input_ids, **generate_kwargs, ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5e00230aed..5ac2a2ccbd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -23,6 +23,13 @@ class Cache(metaclass=DummyObject): requires_backends(self, ["torch"]) +class CacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, 
["torch"]) + + class DynamicCache(metaclass=DummyObject): _backends = ["torch"] @@ -30,6 +37,34 @@ class DynamicCache(metaclass=DummyObject): requires_backends(self, ["torch"]) +class HQQQuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantizedCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class QuantoQuantizedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SinkCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 840b64e17d..7d654312a3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed from transformers.testing_utils import ( is_flaky, require_accelerate, + require_quanto, require_torch, require_torch_multi_accelerator, slow, @@ -55,7 +56,7 @@ if is_torch_available(): ImageGPTForCausalImageModeling, SpeechEncoderDecoderModel, ) - from transformers.cache_utils import DynamicCache + from transformers.cache_utils import DynamicCache, QuantoQuantizedCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1654,6 +1655,39 @@ class GenerationTesterMixin: ) ) + @require_quanto + def test_generate_with_quant_cache(self): + for model_class in self.all_generative_model_classes: + if not model_class._supports_quantized_cache: + self.skipTest("This model does not support the quantized cache format") + + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = True + 
config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "cache_implementation": "quantized", + # careful with group size, should be divisor of model's hidden size + "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache)) + + # passing past key values of different type should raise Error + with self.assertRaises(ValueError): + model.generate( + input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs + ) + + # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense + generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128} + with self.assertRaises(ValueError): + model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/quantization/quanto_integration/test_quanto.py b/tests/quantization/quanto_integration/test_quanto.py index 69bf998ace..f574478241 100644 --- a/tests/quantization/quanto_integration/test_quanto.py +++ b/tests/quantization/quanto_integration/test_quanto.py @@ -17,13 +17,22 @@ import tempfile import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig -from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow +from transformers.testing_utils import ( + require_accelerate, + require_quanto, + require_read_token, + require_torch_gpu, + slow, + torch_device, +) from 
transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available if is_torch_available(): import torch + from transformers import LlamaForCausalLM, LlamaTokenizer + if is_accelerate_available(): from accelerate import init_empty_weights @@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase): with self.assertRaises(ValueError) as e: AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config) self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception)) + + +@require_torch_gpu +class QuantoKVCacheQuantizationTest(unittest.TestCase): + @slow + @require_read_token + def test_quantized_cache(self): + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left") + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device) + + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized") + text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text)