Cache: use batch_size instead of max_batch_size (#32657)
* more precise name * better docstrings * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -47,12 +47,12 @@ if is_torch_available():
|
||||
end_of_text_token = 32000
|
||||
|
||||
class Phi3MiniWithStaticCache(torch.nn.Module):
|
||||
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
|
||||
def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.cache = StaticCache(
|
||||
config=model.config,
|
||||
max_batch_size=max_batch_size,
|
||||
batch_size=batch_size,
|
||||
max_cache_len=max_seq_len,
|
||||
device=self.model.device,
|
||||
dtype=self.model.dtype,
|
||||
|
||||
Reference in New Issue
Block a user