Add a static cache that offloads to the CPU or other device (#32161)
* Add a static cache that offloads to the CPU or other device * Fix PR comments, add unit-tests
This commit is contained in:
@@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
- get_seq_length
|
||||
- reset
|
||||
|
||||
[[autodoc]] OffloadedStaticCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reset
|
||||
|
||||
[[autodoc]] HybridCache
|
||||
- update
|
||||
- get_seq_length
|
||||
|
||||
@@ -97,11 +97,12 @@ with the [`~DynamicCache`] class being the default cache for most models. It all
|
||||
Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
|
||||
|
||||
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|
||||
|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------|
|
||||
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
|
||||
| Dynamic Cache | No | No | No | Mid | No |
|
||||
| Static Cache | No | Yes | Yes | High | No |
|
||||
| Offloaded Cache | Yes | No | No | Low | Yes |
|
||||
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
|
||||
| Quantized Cache | Yes | No | No | Low | Yes |
|
||||
| Offloaded Cache | Yes | No | No | Low | No |
|
||||
| Sliding Window Cache | No | Yes | Yes | High | No |
|
||||
| Sink Cache | Yes | No | Yes | Mid | Yes |
|
||||
|
||||
@@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
|
||||
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
|
||||
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
|
||||
|
||||
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
|
||||
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
|
||||
Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below).
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
@@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded'
|
||||
before successfully generating 40 beams.
|
||||
|
||||
|
||||
|
||||
### Static Cache
|
||||
|
||||
Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
|
||||
@@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
|
||||
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
|
||||
```
|
||||
|
||||
|
||||
## Offloaded Static Cache
|
||||
|
||||
Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports
|
||||
JIT optimizations. Just pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call.
|
||||
This will use the [`~OffloadedStaticCache`] implementation instead.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
|
||||
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
|
||||
|
||||
>>> # simply pass the cache implementation="static"
|
||||
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
|
||||
```
|
||||
|
||||
|
||||
### Sliding Window Cache
|
||||
|
||||
As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile tecniques as Static Cache.
|
||||
|
||||
@@ -1246,6 +1246,7 @@ else:
|
||||
"HybridCache",
|
||||
"MambaCache",
|
||||
"OffloadedCache",
|
||||
"OffloadedStaticCache",
|
||||
"QuantizedCache",
|
||||
"QuantizedCacheConfig",
|
||||
"QuantoQuantizedCache",
|
||||
@@ -6052,6 +6053,7 @@ if TYPE_CHECKING:
|
||||
HybridCache,
|
||||
MambaCache,
|
||||
OffloadedCache,
|
||||
OffloadedStaticCache,
|
||||
QuantizedCache,
|
||||
QuantizedCacheConfig,
|
||||
QuantoQuantizedCache,
|
||||
|
||||
@@ -1708,3 +1708,275 @@ class MambaCache:
|
||||
def reset(self):
|
||||
self.conv_states.zero_()
|
||||
self.ssm_states.zero_()
|
||||
|
||||
|
||||
class OffloadedStaticCache(StaticCache):
|
||||
"""
|
||||
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
|
||||
another device.
|
||||
|
||||
Args:
|
||||
config (`PretrainedConfig):
|
||||
The configuration file defining the shape-related attributes required to initialize
|
||||
the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used.
|
||||
max_cache_len (`int`):
|
||||
The maximum sequence length with which the model will be used.
|
||||
device (`Union[str, torch.device]`):
|
||||
The device on which the cache should be initialized. Should be the same as the
|
||||
layer device.
|
||||
dtype (`torch.dtype`, *optional*):
|
||||
The default `dtype` to use when initializing the cache.
|
||||
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
||||
The device to offload to. Defaults to CPU.
|
||||
|
||||
Attributes:
|
||||
key_cache (`List[torch.Tensor]`):
|
||||
Off-loaded key cache tensors. First one will be on device, where-as the others are
|
||||
off-loaded.
|
||||
value_cache (`List[torch.Tensor]`):
|
||||
Off-loaded value cache tensors. First one will be on device, where-as the others are
|
||||
off-loaded.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which this cache can be used.
|
||||
max_cache_len (`int`):
|
||||
The maximum sequence length with which this cache can be used.
|
||||
device (`torch.device`):
|
||||
The device on which the cache is used.
|
||||
offload_device (`torch.device`):
|
||||
The device used to offload to.
|
||||
dtype (`torch.dtype`):
|
||||
The `dtype` used to initializing the cache.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
||||
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
||||
>>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
max_cache_len: Optional[int],
|
||||
device: Union[str, torch.device],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
) -> None:
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
self.device = torch.device(device)
|
||||
self.offload_device = torch.device(offload_device)
|
||||
self.dtype = dtype if dtype is not None else torch.float32
|
||||
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
|
||||
num_key_value_heads = (
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
)
|
||||
|
||||
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
|
||||
|
||||
# Create offloaded CPU tensors.
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
# First layer is always on-device.
|
||||
device = self.device if i == 0 else self.offload_device
|
||||
|
||||
key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)
|
||||
|
||||
self.key_cache.append(key_cache)
|
||||
self.value_cache.append(value_cache)
|
||||
|
||||
# Create device tensors.
|
||||
self._device_key_cache: List[torch.Tensor] = []
|
||||
self._device_value_cache: List[torch.Tensor] = []
|
||||
|
||||
for i in range(2):
|
||||
key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)
|
||||
|
||||
self._device_key_cache.append(key_cache)
|
||||
self._device_value_cache.append(value_cache)
|
||||
|
||||
# For backwards compatibility.
|
||||
# TODO(gante): Remove this.
|
||||
self._seen_tokens = 0
|
||||
|
||||
# Create new CUDA stream for parallel prefetching.
|
||||
self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
|
||||
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for.
|
||||
cache_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
|
||||
`cache_position` input to know how where to write in the cache.
|
||||
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
|
||||
if layer_idx == 0:
|
||||
# Update seen tokens.
|
||||
# TODO(gante): Remove this.
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
# Always there.
|
||||
k_out = self.key_cache[0]
|
||||
v_out = self.value_cache[0]
|
||||
else:
|
||||
# Wait for prefetch stream.
|
||||
if self._prefetch_stream is not None:
|
||||
torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)
|
||||
|
||||
k_out = self._device_key_cache[layer_idx & 1]
|
||||
v_out = self._device_value_cache[layer_idx & 1]
|
||||
|
||||
self._prefetch_layer(layer_idx + 1)
|
||||
|
||||
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
|
||||
if cache_position is None:
|
||||
k_out.copy_(key_states)
|
||||
v_out.copy_(value_states)
|
||||
|
||||
# Copy the values to the offloaded device as well.
|
||||
if layer_idx == 0:
|
||||
self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
|
||||
self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
|
||||
else:
|
||||
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
|
||||
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does
|
||||
# explicitly an in-place operation, that avoids copies and uses less memory.
|
||||
try:
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
|
||||
# device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
# Copy the values to the offloaded device as well.
|
||||
if layer_idx != 0:
|
||||
cache_position = cache_position.to(self.offload_device)
|
||||
key_states = key_states.to(self.offload_device)
|
||||
value_states = value_states.to(self.offload_device)
|
||||
|
||||
try:
|
||||
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
|
||||
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
|
||||
# device.
|
||||
self.key_cache[layer_idx][:, :, cache_position] = key_states
|
||||
self.value_cache[layer_idx][:, :, cache_position] = value_states
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states that were seen by the model."""
|
||||
|
||||
# TODO(gante): Remove this.
|
||||
return self._seen_tokens
|
||||
|
||||
def get_max_length(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length of the cached states."""
|
||||
|
||||
return self.max_cache_len
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the cache values while preserving the objects."""
|
||||
|
||||
# For backwards compatibility.
|
||||
# TODO(gante): Remove this.
|
||||
self._seen_tokens = 0
|
||||
|
||||
# Zero out cache.
|
||||
for layer_idx in range(len(self.key_cache)):
|
||||
# In-place ops prevent breaking the static address.
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
@property
|
||||
def seen_tokens(self) -> int:
|
||||
# For backwards compatibility.
|
||||
# TODO(gante): Remove this.
|
||||
return self._seen_tokens
|
||||
|
||||
def _create_key_value_cache_tensors(
|
||||
self, shape: Tuple[int, ...], device: torch.device
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
|
||||
addresses for non-CPU tensors.
|
||||
|
||||
Args:
|
||||
shape (`Tuple[int, ...]`): Shape.
|
||||
device (`torch.device`): Device.
|
||||
|
||||
Returns:
|
||||
Key and value cache tensors as a tuple.
|
||||
"""
|
||||
|
||||
is_cpu_device = device == torch.device("cpu")
|
||||
|
||||
key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
||||
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
||||
|
||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||
# preventing compiled graph breaks when updating the cache.
|
||||
torch._dynamo.mark_static_address(key_cache)
|
||||
torch._dynamo.mark_static_address(value_cache)
|
||||
|
||||
return key_cache, value_cache
|
||||
|
||||
def _prefetch_layer(self, layer_idx: int) -> None:
|
||||
"""Prefetch a layer to the device. Needs to be called in order of layer indices."""
|
||||
|
||||
# Don't fetch layers that do not exist.
|
||||
if layer_idx >= len(self.key_cache):
|
||||
return
|
||||
|
||||
# Alternate between two on-device caches.
|
||||
if self._prefetch_stream is not None:
|
||||
with torch.cuda.stream(self._prefetch_stream):
|
||||
self._prefetch_layer_in_context(layer_idx)
|
||||
else:
|
||||
self._prefetch_layer_in_context(layer_idx)
|
||||
|
||||
def _prefetch_layer_in_context(self, layer_idx: int) -> None:
|
||||
"""Performs the actual copy of the layer to device cache."""
|
||||
|
||||
self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
|
||||
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
|
||||
|
||||
@@ -33,6 +33,7 @@ from ..cache_utils import (
|
||||
HybridCache,
|
||||
MambaCache,
|
||||
OffloadedCache,
|
||||
OffloadedStaticCache,
|
||||
QuantizedCacheConfig,
|
||||
QuantoQuantizedCache,
|
||||
SlidingWindowCache,
|
||||
@@ -119,6 +120,7 @@ if is_accelerate_available():
|
||||
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
||||
"static": StaticCache,
|
||||
"offloaded_static": OffloadedStaticCache,
|
||||
"sliding_window": SlidingWindowCache,
|
||||
"hybrid": HybridCache,
|
||||
"mamba": MambaCache,
|
||||
|
||||
@@ -72,6 +72,13 @@ class OffloadedCache(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OffloadedStaticCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class QuantizedCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -380,8 +380,15 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(decoded[0].endswith(last_output))
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa"])
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
@@ -406,7 +413,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
@@ -420,8 +427,15 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa"])
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
|
||||
@parameterized.expand(
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
("eager", "offloaded-static"),
|
||||
("sdpa", "offloaded-static"),
|
||||
]
|
||||
)
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color isЋ the one that complements the skin tone of",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
@@ -446,7 +460,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
@@ -506,7 +520,13 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
def test_static_cache_extra_left_padding(self):
|
||||
@parameterized.expand(
|
||||
[
|
||||
"static",
|
||||
"offloaded-static",
|
||||
]
|
||||
)
|
||||
def test_static_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
@@ -524,7 +544,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
|
||||
Reference in New Issue
Block a user