|
|
|
|
@@ -4,6 +4,7 @@ import importlib.metadata
|
|
|
|
|
import inspect
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from collections.abc import Iterable
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
@@ -24,7 +25,7 @@ if is_hqq_available():
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CacheLayerMixin:
|
|
|
|
|
class CacheLayerMixin(ABC):
|
|
|
|
|
"""Base, abstract class for a single layer's cache."""
|
|
|
|
|
|
|
|
|
|
is_compileable = False
|
|
|
|
|
@@ -32,26 +33,22 @@ class CacheLayerMixin:
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.keys, self.values = None, None
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def update(
|
|
|
|
|
self,
|
|
|
|
|
key_states: torch.Tensor,
|
|
|
|
|
value_states: torch.Tensor,
|
|
|
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
"""Updates KV cache, returns updated keys/values of the layer."""
|
|
|
|
|
raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.")
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor]: ...
|
|
|
|
|
|
|
|
|
|
def get_seq_length(self, cache_position=None) -> int:
|
|
|
|
|
"""Returns the sequence length of this layer's cache."""
|
|
|
|
|
raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.")
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_seq_length(self, cache_position=None) -> int: ...
|
|
|
|
|
|
|
|
|
|
def get_max_cache_shape(self) -> int:
|
|
|
|
|
"""Returns the maximum sequence length (i.e. max capacity) of this layer's cache."""
|
|
|
|
|
raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.")
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_max_cache_shape(self) -> int: ...
|
|
|
|
|
|
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
|
|
|
|
|
"""Returns mask sizes for the layer."""
|
|
|
|
|
raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.")
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
|
|
|
|
|
|
|
|
|
|
def reset(self) -> None:
|
|
|
|
|
"""Resets the cache values while preserving the objects"""
|
|
|
|
|
@@ -76,26 +73,6 @@ class DynamicLayer(CacheLayerMixin):
|
|
|
|
|
See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
|
|
|
|
|
"""
|
|
|
|
|
Build a `DynamicLayer` instance from pre-existing key/value tensors.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
keys (`torch.Tensor`):
|
|
|
|
|
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
|
|
|
|
|
values (`torch.Tensor`):
|
|
|
|
|
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
`DynamicLayer`: The newly constructed layer whose internal cache directly references
|
|
|
|
|
the supplied tensors.
|
|
|
|
|
"""
|
|
|
|
|
layer = cls()
|
|
|
|
|
layer.keys = keys
|
|
|
|
|
layer.values = values
|
|
|
|
|
return layer
|
|
|
|
|
|
|
|
|
|
def update(
|
|
|
|
|
self,
|
|
|
|
|
key_states: torch.Tensor,
|
|
|
|
|
@@ -175,6 +152,26 @@ class DynamicLayer(CacheLayerMixin):
|
|
|
|
|
kv_length = query_length + past_seen_tokens
|
|
|
|
|
return kv_length, kv_offset
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
|
|
|
|
|
"""
|
|
|
|
|
Build a `DynamicLayer` instance from pre-existing key/value tensors.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
keys (`torch.Tensor`):
|
|
|
|
|
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
|
|
|
|
|
values (`torch.Tensor`):
|
|
|
|
|
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
`DynamicLayer`: The newly constructed layer whose internal cache directly references
|
|
|
|
|
the supplied tensors.
|
|
|
|
|
"""
|
|
|
|
|
layer = cls()
|
|
|
|
|
layer.keys = keys
|
|
|
|
|
layer.values = values
|
|
|
|
|
return layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticLayer(CacheLayerMixin):
|
|
|
|
|
"""
|
|
|
|
|
@@ -558,10 +555,10 @@ class OffloadedCacheProcessor(CacheProcessor):
|
|
|
|
|
self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers)
|
|
|
|
|
if self.is_static:
|
|
|
|
|
for i, layer in enumerate(cache.layers):
|
|
|
|
|
device = cache.layer_init_args["device"] if i == 0 else self.offload_device
|
|
|
|
|
device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device
|
|
|
|
|
layer.keys = layer.keys.to(device)
|
|
|
|
|
layer.values = layer.values.to(device)
|
|
|
|
|
self.original_device.append(cache.layer_init_args["device"])
|
|
|
|
|
self.original_device.append(cache.layer_init_kwargs["device"])
|
|
|
|
|
if len(cache) != cache.num_hidden_layers:
|
|
|
|
|
raise ValueError("If static layers are used, all cache layers must be initialized")
|
|
|
|
|
|
|
|
|
|
@@ -1030,15 +1027,15 @@ class Cache:
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
|
layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`):
|
|
|
|
|
A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is
|
|
|
|
|
provided, then it is used for all layers.
|
|
|
|
|
config (`PretrainedConfig`, *optional*):
|
|
|
|
|
Model configuration used to infer number of layers, head sizes, default
|
|
|
|
|
device/dtype, etc.
|
|
|
|
|
cache_processor (`CacheProcessor` or `str`, *optional*):
|
|
|
|
|
Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized")
|
|
|
|
|
or a CacheProcessor class.
|
|
|
|
|
layer_classes (`list[type[CacheLayerMixin]]`, *optional*):
|
|
|
|
|
List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the
|
|
|
|
|
required number of layers the list is cycled. Default is [DynamicLayer].
|
|
|
|
|
max_batch_size (`int`, *optional*): Maximum batch size for static caches.
|
|
|
|
|
max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are
|
|
|
|
|
clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`.
|
|
|
|
|
@@ -1053,9 +1050,9 @@ class Cache:
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]],
|
|
|
|
|
config: Optional[PretrainedConfig] = None,
|
|
|
|
|
cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None,
|
|
|
|
|
layer_classes: Optional[list[type["CacheLayerMixin"]]] = None,
|
|
|
|
|
cache_processor: Optional[Union[str, type[CacheProcessor]]] = None,
|
|
|
|
|
max_batch_size: Optional[int] = None,
|
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
|
device: Union[torch.device, str, None] = None,
|
|
|
|
|
@@ -1064,13 +1061,10 @@ class Cache:
|
|
|
|
|
tp_size: Optional[int] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
self.layers: list["CacheLayerMixin"] = []
|
|
|
|
|
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
|
|
|
|
|
|
|
|
|
|
if layer_classes is None:
|
|
|
|
|
layer_classes = [DynamicLayer]
|
|
|
|
|
|
|
|
|
|
self.layers: list[CacheLayerMixin] = []
|
|
|
|
|
self.layer_classes = layer_classes
|
|
|
|
|
|
|
|
|
|
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
|
|
|
|
|
kwargs.update(
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
max_cache_len=max_cache_len,
|
|
|
|
|
@@ -1080,7 +1074,8 @@ class Cache:
|
|
|
|
|
tp_size=tp_size,
|
|
|
|
|
)
|
|
|
|
|
processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs)
|
|
|
|
|
self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs)
|
|
|
|
|
|
|
|
|
|
self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs)
|
|
|
|
|
self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
|
|
|
|
|
|
|
|
|
|
self.append_new_layers(self.num_hidden_layers - 1)
|
|
|
|
|
@@ -1138,10 +1133,14 @@ class Cache:
|
|
|
|
|
The index of the layer to append.
|
|
|
|
|
"""
|
|
|
|
|
while len(self.layers) <= layer_idx:
|
|
|
|
|
args = self.layer_init_args.copy()
|
|
|
|
|
if self.layer_init_args.get("layer_device_map", None) is not None:
|
|
|
|
|
args["device"] = args.pop("layer_device_map")[layer_idx]
|
|
|
|
|
new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args)
|
|
|
|
|
kwargs = self.layer_init_kwargs.copy()
|
|
|
|
|
if self.layer_init_kwargs.get("layer_device_map", None) is not None:
|
|
|
|
|
kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx]
|
|
|
|
|
|
|
|
|
|
new_layer_class = (
|
|
|
|
|
self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
|
|
|
|
|
)
|
|
|
|
|
new_layer = new_layer_class(**kwargs)
|
|
|
|
|
self.layers.append(new_layer)
|
|
|
|
|
|
|
|
|
|
@apply_processors
|
|
|
|
|
@@ -1294,6 +1293,7 @@ class DynamicCache(Cache):
|
|
|
|
|
|
|
|
|
|
# Specialized constructor for DDP cache data, needed for BC
|
|
|
|
|
def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
|
|
|
|
|
super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
|
|
|
|
|
# `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
|
|
|
|
|
# and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
|
|
|
|
|
# iterable contains the key and value states for a layer gathered across replicas by torch.distributed
|
|
|
|
|
@@ -1303,7 +1303,6 @@ class DynamicCache(Cache):
|
|
|
|
|
if ddp_cache_data is not None:
|
|
|
|
|
for key_states, value_states in ddp_cache_data:
|
|
|
|
|
self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
|
|
|
|
|
"""
|
|
|
|
|
@@ -1390,9 +1389,9 @@ class OffloadedCache(DynamicCache):
|
|
|
|
|
ensure the eviction is scheduled after all computations on that cache are finished.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: Optional[PretrainedConfig] = None) -> None:
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
# Create the underlying cache with offload processor
|
|
|
|
|
super().__init__(cache_processor=OffloadedCacheProcessor, config=config)
|
|
|
|
|
super().__init__(cache_processor=OffloadedCacheProcessor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticCache(Cache):
|
|
|
|
|
@@ -1422,44 +1421,45 @@ class StaticCache(Cache):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(layer_classes=[StaticLayer], *args, **kwargs)
|
|
|
|
|
super().__init__(layer_classes=StaticLayer, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridCache(Cache):
|
|
|
|
|
class OffloadedStaticCache(StaticCache):
|
|
|
|
|
"""
|
|
|
|
|
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
|
|
|
|
|
attention and global attention in every other layer (originally implemented for Gemma2).
|
|
|
|
|
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
|
|
|
|
|
for global attention. For more information, see the documentation of those layer types.
|
|
|
|
|
A drop-in replacement for StaticCache that conserves accelerator memory by offloading
|
|
|
|
|
cache tensors to CPU when not actively being used.
|
|
|
|
|
|
|
|
|
|
This cache maintains the compilation-friendly properties of StaticCache while enabling
|
|
|
|
|
much longer sequences by offloading inactive layers to CPU memory.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
|
|
|
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
|
|
|
|
|
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
|
|
|
|
|
|
|
|
|
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
|
|
|
|
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
|
>>> # Prepare a cache class with offloading
|
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
|
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
|
>>> past_key_values = OffloadedStaticCache(
|
|
|
|
|
... config=model.config,
|
|
|
|
|
... max_batch_size=1,
|
|
|
|
|
... max_cache_len=max_generated_length,
|
|
|
|
|
... device=model.device,
|
|
|
|
|
... dtype=model.dtype
|
|
|
|
|
... )
|
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
|
HybridCache()
|
|
|
|
|
>>> outputs.past_key_values # access cache with offloaded layers
|
|
|
|
|
OffloadedStaticCache()
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: PretrainedConfig, *args, **kwargs):
|
|
|
|
|
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None:
|
|
|
|
|
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
|
|
|
|
|
else:
|
|
|
|
|
layer_classes = [StaticLayer]
|
|
|
|
|
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
|
|
|
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
|
|
|
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SlidingWindowCache(Cache):
|
|
|
|
|
@@ -1502,7 +1502,64 @@ class SlidingWindowCache(Cache):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs)
|
|
|
|
|
super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridCache(Cache):
|
|
|
|
|
"""
|
|
|
|
|
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
|
|
|
|
|
attention and global attention in every other layer (originally implemented for Gemma2).
|
|
|
|
|
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
|
|
|
|
|
for global attention. For more information, see the documentation of those layer types.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
|
|
|
|
|
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
|
|
|
|
|
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
|
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
|
HybridCache()
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: PretrainedConfig, *args, **kwargs):
|
|
|
|
|
if hasattr(config, "layer_types"):
|
|
|
|
|
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
|
|
|
|
|
else:
|
|
|
|
|
# In this case, fall back to StaticCache
|
|
|
|
|
layer_classes = [StaticLayer] * config.num_hidden_layers
|
|
|
|
|
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC
|
|
|
|
|
class HybridChunkedCache(HybridCache): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OffloadedHybridCache(HybridChunkedCache):
|
|
|
|
|
"""
|
|
|
|
|
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
|
|
|
|
|
cache tensors to CPU when not actively being used.
|
|
|
|
|
|
|
|
|
|
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
|
|
|
|
|
much longer sequences by offloading inactive layers to CPU memory.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
|
|
|
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizedCache(DynamicCache):
|
|
|
|
|
@@ -1615,100 +1672,6 @@ class HQQQuantizedCache(QuantizedCache):
|
|
|
|
|
Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OffloadedStaticCache(StaticCache):
|
|
|
|
|
"""
|
|
|
|
|
A drop-in replacement for StaticCache that conserves accelerator memory by offloading
|
|
|
|
|
cache tensors to CPU when not actively being used.
|
|
|
|
|
|
|
|
|
|
This cache maintains the compilation-friendly properties of StaticCache while enabling
|
|
|
|
|
much longer sequences by offloading inactive layers to CPU memory.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
|
|
|
|
|
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
|
|
|
|
|
|
|
|
|
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Prepare a cache class with offloading
|
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
|
>>> past_key_values = OffloadedStaticCache(
|
|
|
|
|
... config=model.config,
|
|
|
|
|
... max_batch_size=1,
|
|
|
|
|
... max_cache_len=max_generated_length,
|
|
|
|
|
... device=model.device,
|
|
|
|
|
... dtype=model.dtype
|
|
|
|
|
... )
|
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
|
>>> outputs.past_key_values # access cache with offloaded layers
|
|
|
|
|
OffloadedStaticCache()
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
|
|
|
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridChunkedCache(Cache):
|
|
|
|
|
"""
|
|
|
|
|
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
|
|
|
|
|
attention and global attention in every other layer, with support for prefill chunking (originally implemented
|
|
|
|
|
for Llama4).
|
|
|
|
|
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
|
|
|
|
|
for global attention. For more information, see the documentation of each subcomponent cache class.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
|
|
|
|
|
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
|
|
|
|
|
|
|
|
|
|
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
|
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
|
HybridCache()
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: PretrainedConfig, *args, **kwargs):
|
|
|
|
|
hybrid_map = LAYER_CLASS_MAP.copy()
|
|
|
|
|
hybrid_map["sliding_attention"] = ChunkedSlidingLayer
|
|
|
|
|
hybrid_map["chunked_attention"] = ChunkedSlidingLayer
|
|
|
|
|
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None:
|
|
|
|
|
layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types]
|
|
|
|
|
else:
|
|
|
|
|
layer_classes = [StaticLayer]
|
|
|
|
|
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OffloadedHybridCache(HybridChunkedCache):
|
|
|
|
|
"""
|
|
|
|
|
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
|
|
|
|
|
cache tensors to CPU when not actively being used.
|
|
|
|
|
|
|
|
|
|
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
|
|
|
|
|
much longer sequences by offloading inactive layers to CPU memory.
|
|
|
|
|
|
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
|
|
|
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderDecoderCache(Cache):
|
|
|
|
|
"""
|
|
|
|
|
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
|
|
|
|
|
@@ -1741,7 +1704,7 @@ class EncoderDecoderCache(Cache):
|
|
|
|
|
is_compileable = None
|
|
|
|
|
|
|
|
|
|
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
|
|
|
|
|
super().__init__()
|
|
|
|
|
super().__init__(layer_classes=DynamicLayer)
|
|
|
|
|
self.self_attention_cache = self_attention_cache
|
|
|
|
|
self.cross_attention_cache = cross_attention_cache
|
|
|
|
|
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)
|
|
|
|
|
@@ -1998,7 +1961,7 @@ def parse_layer_args_from_model_config(
|
|
|
|
|
LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
|
|
|
|
|
"full_attention": StaticLayer,
|
|
|
|
|
"sliding_attention": SlidingWindowLayer,
|
|
|
|
|
"chunked_attention": SlidingWindowLayer,
|
|
|
|
|
"chunked_attention": ChunkedSlidingLayer,
|
|
|
|
|
}
|
|
|
|
|
PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
|
|
|
|
|
"offloaded": OffloadedCacheProcessor,
|
|
|
|
|
|