Add a static cache that offloads to the CPU or other device (#32161)

* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
This commit is contained in:
Gerben van V
2024-08-29 11:51:09 +02:00
committed by GitHub
parent 92a75ff6b1
commit 5129671290
7 changed files with 350 additions and 19 deletions

View File

@@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset
[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset
[[autodoc]] HybridCache
- update
- get_seq_length

View File

@@ -96,14 +96,15 @@ with the [`~DynamicCache`] class being the default cache for most models. It all
Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Quantized Cache | Yes | No | No | Low | Yes |
| Offloaded Cache | Yes | No | No | Low | No |
| Sliding Window Cache| No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Offloaded Cache | Yes | No | No | Low | Yes |
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
| Quantized Cache | Yes | No | No | Low | Yes |
| Sliding Window Cache | No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |
These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.
@@ -142,7 +143,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
## OffloadedCache
## Offloaded Cache
Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
@@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below).
```python
>>> import torch
@@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded'
before successfully generating 40 beams.
### Static Cache
Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
@@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
## Offloaded Static Cache
Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports
JIT optimizations. Just pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call.
This will use the [`~OffloadedStaticCache`] implementation instead.
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
>>> # simply pass the cache implementation="static"
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
### Sliding Window Cache
As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile tecniques as Static Cache.

View File

@@ -1246,6 +1246,7 @@ else:
"HybridCache",
"MambaCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
@@ -6052,6 +6053,7 @@ if TYPE_CHECKING:
HybridCache,
MambaCache,
OffloadedCache,
OffloadedStaticCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,

View File

@@ -1708,3 +1708,275 @@ class MambaCache:
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
class OffloadedStaticCache(StaticCache):
"""
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
another device.
Args:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize
the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
dtype (`torch.dtype`, *optional*):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
Attributes:
key_cache (`List[torch.Tensor]`):
Off-loaded key cache tensors. First one will be on device, where-as the others are
off-loaded.
value_cache (`List[torch.Tensor]`):
Off-loaded value cache tensors. First one will be on device, where-as the others are
off-loaded.
max_batch_size (`int`):
The maximum batch size with which this cache can be used.
max_cache_len (`int`):
The maximum sequence length with which this cache can be used.
device (`torch.device`):
The device on which the cache is used.
offload_device (`torch.device`):
The device used to offload to.
dtype (`torch.dtype`):
The `dtype` used to initializing the cache.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int],
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
# Create offloaded CPU tensors.
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
for i in range(config.num_hidden_layers):
# First layer is always on-device.
device = self.device if i == 0 else self.offload_device
key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)
self.key_cache.append(key_cache)
self.value_cache.append(value_cache)
# Create device tensors.
self._device_key_cache: List[torch.Tensor] = []
self._device_value_cache: List[torch.Tensor] = []
for i in range(2):
key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)
self._device_key_cache.append(key_cache)
self._device_value_cache.append(value_cache)
# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0
# Create new CUDA stream for parallel prefetching.
self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
`cache_position` input to know how where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""
if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.
self._seen_tokens += key_states.shape[-2]
# Always there.
k_out = self.key_cache[0]
v_out = self.value_cache[0]
else:
# Wait for prefetch stream.
if self._prefetch_stream is not None:
torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)
k_out = self._device_key_cache[layer_idx & 1]
v_out = self._device_value_cache[layer_idx & 1]
self._prefetch_layer(layer_idx + 1)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
# Copy the values to the offloaded device as well.
if layer_idx == 0:
self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
else:
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does
# explicitly an in-place operation, that avoids copies and uses less memory.
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
# device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# Copy the values to the offloaded device as well.
if layer_idx != 0:
cache_position = cache_position.to(self.offload_device)
key_states = key_states.to(self.offload_device)
value_states = value_states.to(self.offload_device)
try:
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
# device.
self.key_cache[layer_idx][:, :, cache_position] = key_states
self.value_cache[layer_idx][:, :, cache_position] = value_states
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
# TODO(gante): Remove this.
return self._seen_tokens
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len
def reset(self) -> None:
"""Resets the cache values while preserving the objects."""
# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0
# Zero out cache.
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address.
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@property
def seen_tokens(self) -> int:
# For backwards compatibility.
# TODO(gante): Remove this.
return self._seen_tokens
def _create_key_value_cache_tensors(
self, shape: Tuple[int, ...], device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
addresses for non-CPU tensors.
Args:
shape (`Tuple[int, ...]`): Shape.
device (`torch.device`): Device.
Returns:
Key and value cache tensors as a tuple.
"""
is_cpu_device = device == torch.device("cpu")
key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)
return key_cache, value_cache
def _prefetch_layer(self, layer_idx: int) -> None:
"""Prefetch a layer to the device. Needs to be called in order of layer indices."""
# Don't fetch layers that do not exist.
if layer_idx >= len(self.key_cache):
return
# Alternate between two on-device caches.
if self._prefetch_stream is not None:
with torch.cuda.stream(self._prefetch_stream):
self._prefetch_layer_in_context(layer_idx)
else:
self._prefetch_layer_in_context(layer_idx)
def _prefetch_layer_in_context(self, layer_idx: int) -> None:
"""Performs the actual copy of the layer to device cache."""
self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)

View File

@@ -33,6 +33,7 @@ from ..cache_utils import (
HybridCache,
MambaCache,
OffloadedCache,
OffloadedStaticCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
@@ -119,6 +120,7 @@ if is_accelerate_available():
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"mamba": MambaCache,

View File

@@ -72,6 +72,13 @@ class OffloadedCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class OffloadedStaticCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantizedCache(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -380,8 +380,15 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertTrue(decoded[0].endswith(last_output))
@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -406,7 +413,7 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -420,8 +427,15 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@require_torch_gpu
@parameterized.expand(["eager", "sdpa"])
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
EXPECTED_GENERATION = [
"The best color isЋ the one that complements the skin tone of",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -446,7 +460,7 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
@@ -506,7 +520,13 @@ class CacheIntegrationTest(unittest.TestCase):
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertListEqual(decoded, EXPECTED_GENERATION)
def test_static_cache_extra_left_padding(self):
@parameterized.expand(
[
"static",
"offloaded-static",
]
)
def test_static_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
@@ -524,7 +544,7 @@ class CacheIntegrationTest(unittest.TestCase):
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
model.generation_config.cache_implementation = "static"
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)