Fix DynamicCache and simplify Cache classes a bit (#39590)

* fix

* use kwargs

* simplify

* Update cache_utils.py

* Update cache_utils.py

* Update test_cache_utils.py

* fix

* style
This commit is contained in:
Cyril Vallez
2025-07-23 10:13:45 +02:00
committed by GitHub
parent d9b35c635e
commit 5dba4bc7b2
6 changed files with 147 additions and 184 deletions

View File

@@ -4,6 +4,7 @@ import importlib.metadata
import inspect
import json
import os
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
@@ -24,7 +25,7 @@ if is_hqq_available():
logger = logging.get_logger(__name__)
class CacheLayerMixin:
class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""
is_compileable = False
@@ -32,26 +33,22 @@ class CacheLayerMixin:
def __init__(self):
self.keys, self.values = None, None
@abstractmethod
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Updates KV cache, returns updated keys/values of the layer."""
raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.")
) -> tuple[torch.Tensor, torch.Tensor]: ...
def get_seq_length(self, cache_position=None) -> int:
"""Returns the sequence length of this layer's cache."""
raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.")
@abstractmethod
def get_seq_length(self, cache_position=None) -> int: ...
def get_max_cache_shape(self) -> int:
"""Returns the maximum sequence length (i.e. max capacity) of this layer's cache."""
raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.")
@abstractmethod
def get_max_cache_shape(self) -> int: ...
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Returns mask sizes for the layer."""
raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.")
@abstractmethod
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
def reset(self) -> None:
"""Resets the cache values while preserving the objects"""
@@ -76,26 +73,6 @@ class DynamicLayer(CacheLayerMixin):
See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
"""
@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
"""
Build a `DynamicLayer` instance from pre-existing key/value tensors.
Args:
keys (`torch.Tensor`):
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
values (`torch.Tensor`):
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
Returns:
`DynamicLayer`: The newly constructed layer whose internal cache directly references
the supplied tensors.
"""
layer = cls()
layer.keys = keys
layer.values = values
return layer
def update(
self,
key_states: torch.Tensor,
@@ -175,6 +152,26 @@ class DynamicLayer(CacheLayerMixin):
kv_length = query_length + past_seen_tokens
return kv_length, kv_offset
@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
"""
Build a `DynamicLayer` instance from pre-existing key/value tensors.
Args:
keys (`torch.Tensor`):
Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
values (`torch.Tensor`):
Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
Returns:
`DynamicLayer`: The newly constructed layer whose internal cache directly references
the supplied tensors.
"""
layer = cls()
layer.keys = keys
layer.values = values
return layer
class StaticLayer(CacheLayerMixin):
"""
@@ -558,10 +555,10 @@ class OffloadedCacheProcessor(CacheProcessor):
self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers)
if self.is_static:
for i, layer in enumerate(cache.layers):
device = cache.layer_init_args["device"] if i == 0 else self.offload_device
device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device
layer.keys = layer.keys.to(device)
layer.values = layer.values.to(device)
self.original_device.append(cache.layer_init_args["device"])
self.original_device.append(cache.layer_init_kwargs["device"])
if len(cache) != cache.num_hidden_layers:
raise ValueError("If static layers are used, all cache layers must be initialized")
@@ -1030,15 +1027,15 @@ class Cache:
```
Parameters:
layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`):
A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is
provided, then it is used for all layers.
config (`PretrainedConfig`, *optional*):
Model configuration used to infer number of layers, head sizes, default
device/dtype, etc.
cache_processor (`CacheProcessor` or `str`, *optional*):
Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized")
or a CacheProcessor class.
layer_classes (`list[type[CacheLayerMixin]]`, *optional*):
List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the
required number of layers the list is cycled. Default is [DynamicLayer].
max_batch_size (`int`, *optional*): Maximum batch size for static caches.
max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are
clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`.
@@ -1053,9 +1050,9 @@ class Cache:
def __init__(
self,
layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]],
config: Optional[PretrainedConfig] = None,
cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None,
layer_classes: Optional[list[type["CacheLayerMixin"]]] = None,
cache_processor: Optional[Union[str, type[CacheProcessor]]] = None,
max_batch_size: Optional[int] = None,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
@@ -1064,13 +1061,10 @@ class Cache:
tp_size: Optional[int] = None,
**kwargs,
):
self.layers: list["CacheLayerMixin"] = []
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
if layer_classes is None:
layer_classes = [DynamicLayer]
self.layers: list[CacheLayerMixin] = []
self.layer_classes = layer_classes
processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
kwargs.update(
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
@@ -1080,7 +1074,8 @@ class Cache:
tp_size=tp_size,
)
processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs)
self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs)
self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs)
self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
self.append_new_layers(self.num_hidden_layers - 1)
@@ -1138,10 +1133,14 @@ class Cache:
The index of the layer to append.
"""
while len(self.layers) <= layer_idx:
args = self.layer_init_args.copy()
if self.layer_init_args.get("layer_device_map", None) is not None:
args["device"] = args.pop("layer_device_map")[layer_idx]
new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args)
kwargs = self.layer_init_kwargs.copy()
if self.layer_init_kwargs.get("layer_device_map", None) is not None:
kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx]
new_layer_class = (
self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
)
new_layer = new_layer_class(**kwargs)
self.layers.append(new_layer)
@apply_processors
@@ -1294,6 +1293,7 @@ class DynamicCache(Cache):
# Specialized constructor for DDP cache data, needed for BC
def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
# `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
# and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
# iterable contains the key and value states for a layer gathered across replicas by torch.distributed
@@ -1303,7 +1303,6 @@ class DynamicCache(Cache):
if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
super().__init__(*args, **kwargs)
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
"""
@@ -1390,9 +1389,9 @@ class OffloadedCache(DynamicCache):
ensure the eviction is scheduled after all computations on that cache are finished.
"""
def __init__(self, config: Optional[PretrainedConfig] = None) -> None:
def __init__(self) -> None:
# Create the underlying cache with offload processor
super().__init__(cache_processor=OffloadedCacheProcessor, config=config)
super().__init__(cache_processor=OffloadedCacheProcessor)
class StaticCache(Cache):
@@ -1422,44 +1421,45 @@ class StaticCache(Cache):
"""
def __init__(self, *args, **kwargs):
super().__init__(layer_classes=[StaticLayer], *args, **kwargs)
super().__init__(layer_classes=StaticLayer, *args, **kwargs)
class HybridCache(Cache):
class OffloadedStaticCache(StaticCache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
attention and global attention in every other layer (originally implemented for Gemma2).
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of those layer types.
A drop-in replacement for StaticCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of StaticCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> # Prepare a cache class with offloading
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = OffloadedStaticCache(
... config=model.config,
... max_batch_size=1,
... max_cache_len=max_generated_length,
... device=model.device,
... dtype=model.dtype
... )
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
>>> outputs.past_key_values # access cache with offloaded layers
OffloadedStaticCache()
```
"""
def __init__(self, config: PretrainedConfig, *args, **kwargs):
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None:
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
else:
layer_classes = [StaticLayer]
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class SlidingWindowCache(Cache):
@@ -1502,7 +1502,64 @@ class SlidingWindowCache(Cache):
"""
def __init__(self, *args, **kwargs):
super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs)
super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs)
class HybridCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
attention and global attention in every other layer (originally implemented for Gemma2).
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of those layer types.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
def __init__(self, config: PretrainedConfig, *args, **kwargs):
if hasattr(config, "layer_types"):
layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
else:
# In this case, fall back to StaticCache
layer_classes = [StaticLayer] * config.num_hidden_layers
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
# The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC
class HybridChunkedCache(HybridCache): ...
class OffloadedHybridCache(HybridChunkedCache):
"""
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class QuantizedCache(DynamicCache):
@@ -1615,100 +1672,6 @@ class HQQQuantizedCache(QuantizedCache):
Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
class OffloadedStaticCache(StaticCache):
"""
A drop-in replacement for StaticCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of StaticCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class with offloading
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = OffloadedStaticCache(
... config=model.config,
... max_batch_size=1,
... max_cache_len=max_generated_length,
... device=model.device,
... dtype=model.dtype
... )
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache with offloaded layers
OffloadedStaticCache()
```
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class HybridChunkedCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
attention and global attention in every other layer, with support for prefill chunking (originally implemented
for Llama4).
Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"]
for global attention. For more information, see the documentation of each subcomponent cache class.
See `Cache` for details on common methods that are implemented by all cache classes.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
def __init__(self, config: PretrainedConfig, *args, **kwargs):
hybrid_map = LAYER_CLASS_MAP.copy()
hybrid_map["sliding_attention"] = ChunkedSlidingLayer
hybrid_map["chunked_attention"] = ChunkedSlidingLayer
if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None:
layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types]
else:
layer_classes = [StaticLayer]
super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
class OffloadedHybridCache(HybridChunkedCache):
"""
A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
cache tensors to CPU when not actively being used.
This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
much longer sequences by offloading inactive layers to CPU memory.
See `Cache` for details on common methods that are implemented by all cache classes.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)
class EncoderDecoderCache(Cache):
"""
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
@@ -1741,7 +1704,7 @@ class EncoderDecoderCache(Cache):
is_compileable = None
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
super().__init__()
super().__init__(layer_classes=DynamicLayer)
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)
@@ -1998,7 +1961,7 @@ def parse_layer_args_from_model_config(
LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
"full_attention": StaticLayer,
"sliding_attention": SlidingWindowLayer,
"chunked_attention": SlidingWindowLayer,
"chunked_attention": ChunkedSlidingLayer,
}
PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
"offloaded": OffloadedCacheProcessor,

View File

@@ -31,7 +31,7 @@ from torch import nn
from transformers.activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -104,7 +104,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
super().__init__()
super().__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv

View File

@@ -42,7 +42,7 @@ from transformers.models.mamba2.modeling_mamba2 import (
segment_sum,
)
from ...cache_utils import Cache
from ...cache_utils import DynamicLayer
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -99,7 +99,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False):
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache):
class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache):
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
@@ -114,7 +114,7 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache, Cache):
"""
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
Cache.__init__()
HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv

View File

@@ -27,7 +27,7 @@ from torch import nn
from transformers.activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
@@ -240,7 +240,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False
def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None):
super().__init__()
super().__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv

View File

@@ -28,7 +28,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
@@ -202,7 +202,7 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__()
super().__init__(layer_classes=DynamicLayer)
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba

View File

@@ -1307,7 +1307,7 @@ class SyntheticCacheTest(unittest.TestCase):
config = copy.deepcopy(self.config)
config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"]
config.layer_types = ["full_attention", "chunked_attention"]
config.sliding_window = 2
max_cache_len = 4
chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len)
@@ -1387,7 +1387,7 @@ class SyntheticCacheTest(unittest.TestCase):
config = copy.deepcopy(self.config)
config.num_hidden_layers = 1
config.layer_types = ["sliding_attention"]
config.layer_types = ["chunked_attention"]
config.sliding_window = 3
cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3)