Cache: use batch_size instead of max_batch_size (#32657)

* more precise name

* better docstrings

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-16 11:48:45 +01:00
committed by GitHub
parent 8f9fa3b081
commit cf32ee1753
9 changed files with 112 additions and 54 deletions

View File

@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(