Cache: use batch_size instead of max_batch_size (#32657)
* more precise name * better docstrings * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
max_batch_size=1,
|
||||
batch_size=1,
|
||||
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user