[core] Refactor the Cache logic to make it simpler and more general (#39797)

* Simplify the logic quite a bit

* Update cache_utils.py

* continue work

* continue simplifying a lot

* style

* Update cache_utils.py

* offloading much simpler

* style

* Update cache_utils.py

* update inits

* Update cache_utils.py

* consistency

* Update cache_utils.py

* update generate

* style

* fix

* fix

* add early_initialization

* fix

* fix mamba caches

* update

* fix

* fix

* fix

* fix tests

* fix configs

* revert

* fix tests

* alright

* Update modeling_gptj.py

* fix the constructors

* cache tests

* Update test_cache_utils.py

* fix

* simplify

* back to before -> avoid compile bug

* doc

* mistral test

* llama4 test dtype

* Update test_modeling_llama4.py

* CIs

* Finally find a nice impl

* Update cache_utils.py

* Update cache_utils.py

* add lazy methods in autodoc

* typo

* better doc

* Add detailed docstring for lazy init

* CIs

* style

* fix
Cyril Vallez
2025-08-08 14:47:21 +02:00
committed by GitHub
parent 95510ab018
commit dc11a3cbb2
48 changed files with 771 additions and 1441 deletions
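
The constructor changes in the diff below (dropping `max_batch_size`, `device`, and `dtype` from the `StaticCache` and `HybridCache` calls) go together with the `lazy_initialization` methods listed in the autodoc entries. A minimal sketch of that idea, assuming a layer infers its shapes from the first `update` call: the class `LazyStaticLayer`, its fields, and the nested-list storage are hypothetical stand-ins in plain Python, not the actual `cache_utils.py` code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LazyStaticLayer:
    """Hypothetical stand-in for a static cache layer (not the transformers class)."""

    max_cache_len: int
    # These are filled in on the first update() call, not in the constructor.
    batch_size: Optional[int] = None
    head_dim: Optional[int] = None
    keys: Optional[list] = None

    def lazy_initialization(self, key_states: list) -> None:
        # key_states is indexed as [batch][position][head_dim].
        # Infer the batch size and head dimension from the first incoming
        # states, then preallocate storage for max_cache_len positions.
        self.batch_size = len(key_states)
        self.head_dim = len(key_states[0][0])
        self.keys = [
            [[0.0] * self.head_dim for _ in range(self.max_cache_len)]
            for _ in range(self.batch_size)
        ]

    def update(self, key_states: list, cache_position: list) -> list:
        # Allocate only when the first states arrive.
        if self.keys is None:
            self.lazy_initialization(key_states)
        # Write each new row into its slot in the preallocated buffer.
        for b, rows in enumerate(key_states):
            for pos, row in zip(cache_position, rows):
                self.keys[b][pos] = list(row)
        return self.keys


layer = LazyStaticLayer(max_cache_len=4)
layer.update([[[1.0, 2.0], [3.0, 4.0]]], cache_position=[0, 1])
print(layer.batch_size, layer.head_dim)  # prints "1 2"
```

In this sketch only `max_cache_len` has to be known up front, which mirrors the shorter constructor calls in the updated documentation; everything else is read off the first tensors the layer sees.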

@@ -363,37 +363,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_max_cache_shape
- reset
- reorder_cache
- lazy_initialization
[[autodoc]] DynamicLayer
- update
- lazy_initialization
- crop
- batch_repeat_interleave
- batch_select_indices
[[autodoc]] StaticLayer
- update
- lazy_initialization
[[autodoc]] SlidingWindowLayer
- update
- lazy_initialization
[[autodoc]] CacheProcessor
- pre_update
- post_update
[[autodoc]] QuantoQuantizedLayer
- update
- lazy_initialization
[[autodoc]] OffloadedCacheProcessor
- pre_update
[[autodoc]] QuantizedCacheProcessor
- post_update
[[autodoc]] QuantoQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedLayer
- update
- lazy_initialization
[[autodoc]] Cache
- update
- early_initialization
- get_seq_length
- get_mask_sizes
- get_max_cache_shape
@@ -411,12 +408,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] QuantoQuantizedCache
[[autodoc]] QuantoQuantizedCacheProcessor
[[autodoc]] HQQQuantizedCache
[[autodoc]] HQQQuantizedCacheProcessor
[[autodoc]] OffloadedCache
[[autodoc]] StaticCache

@@ -312,7 +312,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
-prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device.type, dtype=torch.bfloat16)
+prompt_cache = StaticCache(config=model.config, max_cache_len=1024)
INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type)

@@ -93,11 +93,8 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
-max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
-device=model.device,
-dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -159,7 +156,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
-config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+config=model.config, max_cache_len=4096
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(

@@ -138,8 +138,7 @@ visualizer("You are an assistant. Make sure you print me")
inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
max_generated_length = inputs.input_ids.shape[1] + 10
-past_key_values = HybridCache(config=model.config, max_batch_size=1,
-max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```

@@ -362,21 +362,11 @@ generation_output[:2]
[[autodoc]] SlidingWindowLayer
- update
[[autodoc]] CacheProcessor
- pre_update
- post_update
[[autodoc]] QuantoQuantizedLayer
- update
[[autodoc]] OffloadedCacheProcessor
- pre_update
[[autodoc]] QuantizedCacheProcessor
- post_update
[[autodoc]] QuantoQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedLayer
- update
[[autodoc]] Cache
- update
@@ -397,12 +387,8 @@ generation_output[:2]
[[autodoc]] QuantoQuantizedCache
[[autodoc]] QuantoQuantizedCacheProcessor
[[autodoc]] HQQQuantizedCache
[[autodoc]] HQQQuantizedCacheProcessor
[[autodoc]] OffloadedCache
[[autodoc]] StaticCache

@@ -99,11 +99,8 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
-max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
-device=model.device,
-dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -161,7 +158,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
-config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+config=model.config, max_cache_len=4096
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(