Remove deprecated batch_size parameter (#37007)

This commit is contained in:
cyyever
2025-03-27 23:01:56 +08:00
committed by GitHub
parent 4cc65e990f
commit 6cc9c8d7d1
4 changed files with 50 additions and 97 deletions

View File

@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
return random_keys, random_values
mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@@ -167,7 +167,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)