[generate] move SinkCache to a custom_generate repo (#38399)
remove sink cache
This commit is contained in:
@@ -380,11 +380,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] SinkCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
- update
|
||||
- prefetch_layer
|
||||
@@ -443,4 +438,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
|
||||
[[autodoc]] CompileConfig
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ Transformers offers several [`Cache`] classes that implement different caching m
|
||||
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
|
||||
| Quantized Cache | Yes | No | No | Low | Yes |
|
||||
| Sliding Window Cache | No | Yes | Yes | High | No |
|
||||
| Sink Cache | Yes | No | Yes | Mid | Yes |
|
||||
|
||||
This guide introduces you to the different [`Cache`] classes and shows you how to use them for generation.
|
||||
|
||||
@@ -174,28 +173,6 @@ I like rock music because it's loud and energetic. It's a great way to express m
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Sink cache
|
||||
|
||||
[`SinkCache`] is capable of generating very long sequences ("infinite length" according to the paper) by only retaining a few initial tokens from the sequence. These are called the *sink tokens* because they account for a significant portion of the attention scores during generation. Subsequent tokens are discarded on a sliding windowed basis, and only the latest `window_size` tokens are kept. This means most of the previous knowledge is discarded.
|
||||
|
||||
The sink tokens allow a model to maintain stable performance even when it's dealing with very long text sequences.
|
||||
|
||||
Enable [`SinkCache`] by initializing it first with the [window_length](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.window_length) and [num_sink_tokens](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.num_sink_tokens) parameters before passing it to [past_key_values](https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values) in [`~GenerationMixin.generate`].
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
|
||||
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)
|
||||
|
||||
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
|
||||
out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
|
||||
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"
|
||||
```
|
||||
|
||||
## Speed optimized caches
|
||||
|
||||
The default [`DynamicCache`] prevents you from taking advantage of just-in-time (JIT) optimizations because the cache size isn't fixed. JIT optimizations enable you to maximize latency at the expense of memory usage. All of the following cache types are compatible with JIT optimizations like [torch.compile](./llm_optims#static-kv-cache-and-torchcompile) to accelerate generation.
|
||||
@@ -247,7 +224,7 @@ Enable [`SlidingWindowCache`] by configuring `cache_implementation="sliding_wind
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
|
||||
@@ -284,8 +261,6 @@ A cache can also work in iterative generation settings where there is back-and-f
|
||||
|
||||
For iterative generation with a cache, start by initializing an empty cache class and then you can feed in your new prompts. Keep track of dialogue history with a [chat template](./chat_templating).
|
||||
|
||||
If you're using [`SinkCache`], the inputs need to be truncated to the maximum length because [`SinkCache`] can generate text that exceeds its maximum window size. However, the first input shouldn't exceed the maximum cache length.
|
||||
|
||||
The example below demonstrates how to use a cache for iterative generation.
|
||||
|
||||
```py
|
||||
@@ -293,7 +268,6 @@ import torch
|
||||
from transformers import AutoTokenizer,AutoModelForCausalLM
|
||||
from transformers.cache_utils import (
|
||||
DynamicCache,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
SlidingWindowCache,
|
||||
QuantoQuantizedCache,
|
||||
@@ -313,8 +287,6 @@ messages = []
|
||||
for prompt in user_prompts:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
||||
if isinstance(past_key_values, SinkCache):
|
||||
inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
|
||||
input_length = inputs["input_ids"].shape[1]
|
||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
|
||||
completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
|
||||
@@ -336,7 +308,7 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Init StaticCache with big enough max-length (1024 tokens for the below example)
|
||||
# Init StaticCache with big enough max-length (1024 tokens for the below example)
|
||||
# You can also init a DynamicCache, if that suits you better
|
||||
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
@@ -351,7 +323,7 @@ responses = []
|
||||
for prompt in prompts:
|
||||
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
|
||||
past_key_values = copy.deepcopy(prompt_cache)
|
||||
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
|
||||
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
|
||||
response = tokenizer.batch_decode(outputs)[0]
|
||||
responses.append(response)
|
||||
|
||||
|
||||
@@ -366,11 +366,6 @@ generation_output[:2]
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] SinkCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
- update
|
||||
- prefetch_layer
|
||||
|
||||
Reference in New Issue
Block a user