[core] Refactor the Cache logic to make it simpler and more general (#39797)
* Simplify the logic quite a bit * Update cache_utils.py * continue work * continue simplifying a lot * style * Update cache_utils.py * offloading much simpler * style * Update cache_utils.py * update inits * Update cache_utils.py * consistency * Update cache_utils.py * update generate * style * fix * fix * add early_initialization * fix * fix mamba caches * update * fix * fix * fix * fix tests * fix configs * revert * fix tests * alright * Update modeling_gptj.py * fix the constructors * cache tests * Update test_cache_utils.py * fix * simplify * back to before -> avoid compile bug * doc * mistral test * llama4 test dtype * Update test_modeling_llama4.py * CIs * Finally find a nice impl * Update cache_utils.py * Update cache_utils.py * add lazy methods in autodoc * typo * better doc * Add detailed docstring for lazy init * CIs * style * fix
This commit is contained in:
@@ -363,37 +363,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
- get_max_cache_shape
|
||||
- reset
|
||||
- reorder_cache
|
||||
- lazy_initialization
|
||||
|
||||
[[autodoc]] DynamicLayer
|
||||
- update
|
||||
- lazy_initialization
|
||||
- crop
|
||||
- batch_repeat_interleave
|
||||
- batch_select_indices
|
||||
|
||||
[[autodoc]] StaticLayer
|
||||
- update
|
||||
- lazy_initialization
|
||||
|
||||
[[autodoc]] SlidingWindowLayer
|
||||
- update
|
||||
- lazy_initialization
|
||||
|
||||
[[autodoc]] CacheProcessor
|
||||
- pre_update
|
||||
- post_update
|
||||
[[autodoc]] QuantoQuantizedLayer
|
||||
- update
|
||||
- lazy_initialization
|
||||
|
||||
[[autodoc]] OffloadedCacheProcessor
|
||||
- pre_update
|
||||
|
||||
[[autodoc]] QuantizedCacheProcessor
|
||||
- post_update
|
||||
|
||||
[[autodoc]] QuantoQuantizedCacheProcessor
|
||||
- post_update
|
||||
|
||||
[[autodoc]] HQQQuantizedCacheProcessor
|
||||
- post_update
|
||||
[[autodoc]] HQQQuantizedLayer
|
||||
- update
|
||||
- lazy_initialization
|
||||
|
||||
[[autodoc]] Cache
|
||||
- update
|
||||
- early_initialization
|
||||
- get_seq_length
|
||||
- get_mask_sizes
|
||||
- get_max_cache_shape
|
||||
@@ -411,12 +408,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
|
||||
[[autodoc]] QuantoQuantizedCache
|
||||
|
||||
[[autodoc]] QuantoQuantizedCacheProcessor
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] HQQQuantizedCacheProcessor
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
|
||||
[[autodoc]] StaticCache
|
||||
|
||||
@@ -312,7 +312,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Init StaticCache with big enough max-length (1024 tokens for the below example)
|
||||
# You can also init a DynamicCache, if that suits you better
|
||||
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device.type, dtype=torch.bfloat16)
|
||||
prompt_cache = StaticCache(config=model.config, max_cache_len=1024)
|
||||
|
||||
INITIAL_PROMPT = "You are a helpful assistant. "
|
||||
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type)
|
||||
|
||||
@@ -93,11 +93,8 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
max_batch_size=1,
|
||||
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
dtype=model.dtype
|
||||
)
|
||||
outputs = model.generate(**input_ids, past_key_values=past_key_values)
|
||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
@@ -159,7 +156,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
config=model.config, max_cache_len=4096
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
|
||||
@@ -138,8 +138,7 @@ visualizer("You are an assistant. Make sure you print me")
|
||||
|
||||
inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
||||
max_generated_length = inputs.input_ids.shape[1] + 10
|
||||
past_key_values = HybridCache(config=model.config, max_batch_size=1,
|
||||
max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||
past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
|
||||
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
```
|
||||
|
||||
|
||||
@@ -362,21 +362,11 @@ generation_output[:2]
|
||||
[[autodoc]] SlidingWindowLayer
|
||||
- update
|
||||
|
||||
[[autodoc]] CacheProcessor
|
||||
- pre_update
|
||||
- post_update
|
||||
[[autodoc]] QuantoQuantizedLayer
|
||||
- update
|
||||
|
||||
[[autodoc]] OffloadedCacheProcessor
|
||||
- pre_update
|
||||
|
||||
[[autodoc]] QuantizedCacheProcessor
|
||||
- post_update
|
||||
|
||||
[[autodoc]] QuantoQuantizedCacheProcessor
|
||||
- post_update
|
||||
|
||||
[[autodoc]] HQQQuantizedCacheProcessor
|
||||
- post_update
|
||||
[[autodoc]] HQQQuantizedLayer
|
||||
- update
|
||||
|
||||
[[autodoc]] Cache
|
||||
- update
|
||||
@@ -397,12 +387,8 @@ generation_output[:2]
|
||||
|
||||
[[autodoc]] QuantoQuantizedCache
|
||||
|
||||
[[autodoc]] QuantoQuantizedCacheProcessor
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] HQQQuantizedCacheProcessor
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
|
||||
[[autodoc]] StaticCache
|
||||
|
||||
@@ -99,11 +99,8 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
max_batch_size=1,
|
||||
# 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
dtype=model.dtype
|
||||
)
|
||||
outputs = model.generate(**input_ids, past_key_values=past_key_values)
|
||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
@@ -161,7 +158,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
config=model.config, max_cache_len=4096
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user