Cache: use batch_size instead of max_batch_size (#32657)

* more precise name

* better docstrings

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-16 11:48:45 +01:00
committed by GitHub
parent 8f9fa3b081
commit cf32ee1753
9 changed files with 112 additions and 54 deletions

View File

@@ -1040,7 +1040,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
@@ -1088,7 +1088,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,