Fix DynamicCache and simplify Cache classes a bit (#39590)

* fix

* use kwargs

* simplify

* Update cache_utils.py

* Update cache_utils.py

* Update test_cache_utils.py

* fix

* style
This commit is contained in:
Cyril Vallez
2025-07-23 10:13:45 +02:00
committed by GitHub
parent d9b35c635e
commit 5dba4bc7b2
6 changed files with 147 additions and 184 deletions

View File

@@ -4,6 +4,7 @@ import importlib.metadata
import inspect import inspect
import json import json
import os import os
from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
@@ -24,7 +25,7 @@ if is_hqq_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class CacheLayerMixin: class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache.""" """Base, abstract class for a single layer's cache."""
is_compileable = False is_compileable = False
@@ -32,26 +33,22 @@ class CacheLayerMixin:
def __init__(self): def __init__(self):
self.keys, self.values = None, None self.keys, self.values = None, None
@abstractmethod
def update( def update(
self, self,
key_states: torch.Tensor, key_states: torch.Tensor,
value_states: torch.Tensor, value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None, cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]: ...
"""Updates KV cache, returns updated keys/values of the layer."""
raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.")
def get_seq_length(self, cache_position=None) -> int: @abstractmethod
"""Returns the sequence length of this layer's cache.""" def get_seq_length(self, cache_position=None) -> int: ...
raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.")
def get_max_cache_shape(self) -> int: @abstractmethod
"""Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" def get_max_cache_shape(self) -> int: ...
raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.")
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: @abstractmethod
"""Returns mask sizes for the layer.""" def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.")
def reset(self) -> None: def reset(self) -> None:
"""Resets the cache values while preserving the objects""" """Resets the cache values while preserving the objects"""
@@ -76,26 +73,6 @@ class DynamicLayer(CacheLayerMixin):
See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
""" """
@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
"""
Build a `DynamicLayer` instance from pre-existing key/value tensors.
Args:
keys (`torch.Tensor`):
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
values (`torch.Tensor`):
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
Returns:
`DynamicLayer`: The newly constructed layer whose internal cache directly references
the supplied tensors.
"""
layer = cls()
layer.keys = keys
layer.values = values
return layer
def update( def update(
self, self,
key_states: torch.Tensor, key_states: torch.Tensor,
@@ -175,6 +152,26 @@ class DynamicLayer(CacheLayerMixin):
kv_length = query_length + past_seen_tokens kv_length = query_length + past_seen_tokens
return kv_length, kv_offset return kv_length, kv_offset
@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
"""
Build a `DynamicLayer` instance from pre-existing key/value tensors.
Args:
keys (`torch.Tensor`):
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
values (`torch.Tensor`):
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
Returns:
`DynamicLayer`: The newly constructed layer whose internal cache directly references
the supplied tensors.
"""
layer = cls()
layer.keys = keys
layer.values = values
return layer
class StaticLayer(CacheLayerMixin): class StaticLayer(CacheLayerMixin):
""" """
@@ -558,10 +555,10 @@ class OffloadedCacheProcessor(CacheProcessor):
self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers)
if self.is_static: if self.is_static:
for i, layer in enumerate(cache.layers): for i, layer in enumerate(cache.layers):
device = cache.layer_init_args["device"] if i == 0 else self.offload_device device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device
layer.keys = layer.keys.to(device) layer.keys = layer.keys.to(device)
layer.values = layer.values.to(device) layer.values = layer.values.to(device)
self.original_device.append(cache.layer_init_args["device"]) self.original_device.append(cache.layer_init_kwargs["device"])
if len(cache) != cache.num_hidden_layers: if len(cache) != cache.num_hidden_layers:
raise ValueError("If static layers are used, all cache layers must be initialized") raise ValueError("If static layers are used, all cache layers must be initialized")
@@ -1030,15 +1027,15 @@ class Cache:
``` ```
Parameters: Parameters:
layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`):
A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is
provided, then it is used for all layers.
config (`PretrainedConfig`, *optional*): config (`PretrainedConfig`, *optional*):
Model configuration used to infer number of layers, head sizes, default Model configuration used to infer number of layers, head sizes, default
device/dtype, etc. device/dtype, etc.
cache_processor (`CacheProcessor` or `str`, *optional*): cache_processor (`CacheProcessor` or `str`, *optional*):
Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized")
or a CacheProcessor class. or a CacheProcessor class.
layer_classes (`list[type[CacheLayerMixin]]`, *optional*):
List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the
required number of layers the list is cycled. Default is [DynamicLayer].
max_batch_size (`int`, *optional*): Maximum batch size for static caches. max_batch_size (`int`, *optional*): Maximum batch size for static caches.
max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are
clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`.
@@ -1053,9 +1050,9 @@ class Cache:
def __init__( def __init__(
self, self,
layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]],
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,
cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, cache_processor: Optional[Union[str, type[CacheProcessor]]] = None,
layer_classes: Optional[list[type["CacheLayerMixin"]]] = None,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
max_cache_len: Optional[int] = None, max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None,
@@ -1064,13 +1061,10 @@ class Cache:
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
**kwargs, **kwargs,
): ):
self.layers: list["CacheLayerMixin"] = [] self.layers: list[CacheLayerMixin] = []
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
if layer_classes is None:
layer_classes = [DynamicLayer]
self.layer_classes = layer_classes self.layer_classes = layer_classes
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
kwargs.update( kwargs.update(
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_cache_len=max_cache_len, max_cache_len=max_cache_len,
@@ -1080,7 +1074,8 @@ class Cache:
tp_size=tp_size, tp_size=tp_size,
) )
processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs)
self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs)
self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs)
self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
self.append_new_layers(self.num_hidden_layers - 1) self.append_new_layers(self.num_hidden_layers - 1)
@@ -1138,10 +1133,14 @@ class Cache:
The index of the layer to append. The index of the layer to append.
""" """
while len(self.layers) <= layer_idx: while len(self.layers) <= layer_idx:
args = self.layer_init_args.copy() kwargs = self.layer_init_kwargs.copy()
if self.layer_init_args.get("layer_device_map", None) is not None: if self.layer_init_kwargs.get("layer_device_map", None) is not None:
args["device"] = args.pop("layer_device_map")[layer_idx] kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx]
new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args)
new_layer_class = (
self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
)
new_layer = new_layer_class(**kwargs)
self.layers.append(new_layer) self.layers.append(new_layer)
@apply_processors @apply_processors
@@ -1294,6 +1293,7 @@ class DynamicCache(Cache):
# Specialized constructor for DDP cache data, needed for BC # Specialized constructor for DDP cache data, needed for BC
def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
# `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
# and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
# iterable contains the key and value states for a layer gathered across replicas by torch.distributed # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
@@ -1303,7 +1303,6 @@ class DynamicCache(Cache):
if ddp_cache_data is not None: if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data: for key_states, value_states in ddp_cache_data:
self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
super().__init__(*args, **kwargs)
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
""" """
@@ -1390,9 +1389,9 @@ class OffloadedCache(DynamicCache):
ensure the eviction is scheduled after all computations on that cache are finished. ensure the eviction is scheduled after all computations on that cache are finished.
""" """
def __init__(self, config: Optional[PretrainedConfig] = None) -> None: def __init__(self) -> None:
# Create the underlying cache with offload processor # Create the underlying cache with offload processor
super().__init__(cache_processor=OffloadedCacheProcessor, config=config) super().__init__(cache_processor=OffloadedCacheProcessor)
class StaticCache(Cache): class StaticCache(Cache):
@@ -1422,44 +1421,45 @@ class StaticCache(Cache):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(layer_classes=[StaticLayer], *args, **kwargs) super().__init__(layer_classes=StaticLayer, *args, **kwargs)
class HybridCache(Cache): class OffloadedStaticCache(StaticCache):
""" """
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window A drop-in replacement for StaticCache that conserves accelerator memory by offloading
attention and global attention in every other layer (originally implemented for Gemma2). cache tensors to CPU when not actively being used.
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of those layer types. This cache maintains the compilation-friendly properties of StaticCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes. See `Cache` for details on common methods that are implemented by all cache classes.
Example: Example:
```python ```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class with offloading
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = OffloadedStaticCache(
... config=model.config,
... max_batch_size=1,
... max_cache_len=max_generated_length,
... device=model.device,
... dtype=model.dtype
... )
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation >>> outputs.past_key_values # access cache with offloaded layers
HybridCache() OffloadedStaticCache()
``` ```
""" """
def __init__(self, config: PretrainedConfig, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
else:
layer_classes = [StaticLayer]
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
class SlidingWindowCache(Cache): class SlidingWindowCache(Cache):
@@ -1502,7 +1502,64 @@ class SlidingWindowCache(Cache):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs)
class HybridCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
attention and global attention in every other layer (originally implemented for Gemma2).
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of those layer types.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
def __init__(self, config: PretrainedConfig, *args, **kwargs):
if hasattr(config, "layer_types"):
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
else:
# In this case, fall back to StaticCache
layer_classes = [StaticLayer] * config.num_hidden_layers
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
# The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC
class HybridChunkedCache(HybridCache): ...
class OffloadedHybridCache(HybridChunkedCache):
"""
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class QuantizedCache(DynamicCache): class QuantizedCache(DynamicCache):
@@ -1615,100 +1672,6 @@ class HQQQuantizedCache(QuantizedCache):
Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
class OffloadedStaticCache(StaticCache):
"""
A drop-in replacement for StaticCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of StaticCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class with offloading
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = OffloadedStaticCache(
... config=model.config,
... max_batch_size=1,
... max_cache_len=max_generated_length,
... device=model.device,
... dtype=model.dtype
... )
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache with offloaded layers
OffloadedStaticCache()
```
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class HybridChunkedCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
attention and global attention in every other layer, with support for prefill chunking (originally implemented
for Llama4).
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of each subcomponent cache class.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
def __init__(self, config: PretrainedConfig, *args, **kwargs):
hybrid_map = LAYER_CLASS_MAP.copy()
hybrid_map["sliding_attention"] = ChunkedSlidingLayer
hybrid_map["chunked_attention"] = ChunkedSlidingLayer
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None:
layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types]
else:
layer_classes = [StaticLayer]
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
class OffloadedHybridCache(HybridChunkedCache):
"""
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class EncoderDecoderCache(Cache): class EncoderDecoderCache(Cache):
""" """
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
@@ -1741,7 +1704,7 @@ class EncoderDecoderCache(Cache):
is_compileable = None is_compileable = None
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
super().__init__() super().__init__(layer_classes=DynamicLayer)
self.self_attention_cache = self_attention_cache self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache self.cross_attention_cache = cross_attention_cache
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)
@@ -1998,7 +1961,7 @@ def parse_layer_args_from_model_config(
LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
"full_attention": StaticLayer, "full_attention": StaticLayer,
"sliding_attention": SlidingWindowLayer, "sliding_attention": SlidingWindowLayer,
"chunked_attention": SlidingWindowLayer, "chunked_attention": ChunkedSlidingLayer,
} }
PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
"offloaded": OffloadedCacheProcessor, "offloaded": OffloadedCacheProcessor,

View File

@@ -31,7 +31,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -104,7 +104,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False is_compileable = False
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
super().__init__() super().__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv conv_kernel_size = config.mamba_d_conv

View File

@@ -42,7 +42,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
segment_sum, segment_sum,
) )
from ...cache_utils import Cache from ...cache_utils import DynamicLayer
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
@@ -99,7 +99,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False):
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache): class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache):
""" """
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len). (which has a constant shape regardless of seq_len).
@@ -114,7 +114,7 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache):
""" """
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
Cache.__init__() HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv conv_kernel_size = config.mamba_d_conv

View File

@@ -27,7 +27,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
@@ -240,7 +240,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False is_compileable = False
def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None):
super().__init__() super().__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv conv_kernel_size = config.mamba_d_conv

View File

@@ -28,7 +28,7 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
@@ -202,7 +202,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False is_compileable = False
def __init__(self, config, batch_size, dtype=torch.float16, device=None): def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__() super().__init__(layer_classes=DynamicLayer)
self.dtype = dtype self.dtype = dtype
self.layers_block_type = config.layers_block_type self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba self.has_previous_state = False # only used by mamba

View File

@@ -1307,7 +1307,7 @@ class SyntheticCacheTest(unittest.TestCase):
config = copy.deepcopy(self.config) config = copy.deepcopy(self.config)
config.num_hidden_layers = 2 config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"] config.layer_types = ["full_attention", "chunked_attention"]
config.sliding_window = 2 config.sliding_window = 2
max_cache_len = 4 max_cache_len = 4
chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len)
@@ -1387,7 +1387,7 @@ class SyntheticCacheTest(unittest.TestCase):
config = copy.deepcopy(self.config) config = copy.deepcopy(self.config)
config.num_hidden_layers = 1 config.num_hidden_layers = 1
config.layer_types = ["sliding_attention"] config.layer_types = ["chunked_attention"]
config.sliding_window = 3 config.sliding_window = 3
cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3) cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3)