Add a static cache that offloads to the CPU or other device (#32161)
* Add a static cache that offloads to the CPU or other device * Fix PR comments, add unit-tests
This commit is contained in:
@@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
- get_seq_length
|
||||
- reset
|
||||
|
||||
[[autodoc]] OffloadedStaticCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reset
|
||||
|
||||
[[autodoc]] HybridCache
|
||||
- update
|
||||
- get_seq_length
|
||||
|
||||
@@ -96,14 +96,15 @@ with the [`~DynamicCache`] class being the default cache for most models. It all
|
||||
|
||||
Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
|
||||
|
||||
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|
||||
|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------|
|
||||
| Dynamic Cache | No | No | No | Mid | No |
|
||||
| Static Cache | No | Yes | Yes | High | No |
|
||||
| Quantized Cache | Yes | No | No | Low | Yes |
|
||||
| Offloaded Cache | Yes | No | No | Low | No |
|
||||
| Sliding Window Cache| No | Yes | Yes | High | No |
|
||||
| Sink Cache | Yes | No | Yes | Mid | Yes |
|
||||
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|
||||
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
|
||||
| Dynamic Cache | No | No | No | Mid | No |
|
||||
| Static Cache | No | Yes | Yes | High | No |
|
||||
| Offloaded Cache | Yes | No | No | Low | Yes |
|
||||
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
|
||||
| Quantized Cache | Yes | No | No | Low | Yes |
|
||||
| Sliding Window Cache | No | Yes | Yes | High | No |
|
||||
| Sink Cache | Yes | No | Yes | Mid | Yes |
|
||||
|
||||
|
||||
These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.
|
||||
@@ -142,7 +143,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
|
||||
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
|
||||
```
|
||||
|
||||
## OffloadedCache
|
||||
## Offloaded Cache
|
||||
|
||||
Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
|
||||
It does so by moving the KV cache for most layers to the CPU.
|
||||
@@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
|
||||
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
|
||||
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
|
||||
|
||||
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
|
||||
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
|
||||
Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below).
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
@@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded'
|
||||
before successfully generating 40 beams.
|
||||
|
||||
|
||||
|
||||
### Static Cache
|
||||
|
||||
Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
|
||||
@@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
|
||||
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
|
||||
```
|
||||
|
||||
|
||||
## Offloaded Static Cache
|
||||
|
||||
Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports
|
||||
JIT optimizations. Just pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call.
|
||||
This will use the [`~OffloadedStaticCache`] implementation instead.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
|
||||
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
|
||||
|
||||
>>> # simply pass the cache implementation="static"
|
||||
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
|
||||
```
|
||||
|
||||
|
||||
### Sliding Window Cache
|
||||
|
||||
As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile tecniques as Static Cache.
|
||||
|
||||
Reference in New Issue
Block a user