Cache: use batch_size instead of max_batch_size (#32657)
* more precise name * better docstrings * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -145,7 +145,7 @@ class CacheTest(unittest.TestCase):
|
||||
return random_keys, random_values
|
||||
|
||||
mha_config = LlamaConfig(num_attention_heads=32)
|
||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mha_static_cache.update(
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@@ -153,7 +153,7 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||
|
||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = gqa_static_cache.update(
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@@ -161,7 +161,7 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||
|
||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mqa_static_cache.update(
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
|
||||
|
||||
device = "cpu"
|
||||
dtype = torch.float32
|
||||
max_batch_size = 1
|
||||
batch_size = 1
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
@@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase):
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.static_cache = StaticCache(
|
||||
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
|
||||
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
|
||||
)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user