Cache: use batch_size instead of max_batch_size (#32657)

* more precise name

* better docstrings

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-16 11:48:45 +01:00
committed by GitHub
parent 8f9fa3b081
commit cf32ee1753
9 changed files with 112 additions and 54 deletions

View File

@@ -47,12 +47,12 @@ if is_torch_available():
     end_of_text_token = 32000
 
     class Phi3MiniWithStaticCache(torch.nn.Module):
-        def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
+        def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
             super().__init__()
             self.model = model
             self.cache = StaticCache(
                 config=model.config,
-                max_batch_size=max_batch_size,
+                batch_size=batch_size,
                 max_cache_len=max_seq_len,
                 device=self.model.device,
                 dtype=self.model.dtype,