Add a static cache that offloads to the CPU or other device (#32161)

* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
This commit is contained in:
Gerben van V
2024-08-29 11:51:09 +02:00
committed by GitHub
parent 92a75ff6b1
commit 5129671290
7 changed files with 350 additions and 19 deletions

View File

@@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset
[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset
[[autodoc]] HybridCache
- update
- get_seq_length

View File

@@ -96,14 +96,15 @@ with the [`~DynamicCache`] class being the default cache for most models. It all
Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Quantized Cache | Yes | No | No | Low | Yes |
| Offloaded Cache | Yes | No | No | Low | No |
| Sliding Window Cache| No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Offloaded Cache | Yes | No | No | Low | Yes |
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
| Quantized Cache | Yes | No | No | Low | Yes |
| Sliding Window Cache | No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |
These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.
@@ -142,7 +143,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
## OffloadedCache
## Offloaded Cache
Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
@@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below).
```python
>>> import torch
@@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded'
before successfully generating 40 beams.
### Static Cache
Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
@@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
## Offloaded Static Cache
Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports
JIT optimizations. Just pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call.
This will use the [`~OffloadedStaticCache`] implementation instead.
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
>>> # simply pass the cache implementation="static"
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
### Sliding Window Cache
As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile tecniques as Static Cache.