From 3e70a207dfa6408c440042f2f8076dd6bfb43e8b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 13 Feb 2024 09:58:19 +0000 Subject: [PATCH] Static Cache: load models with MQA or GQA (#28975) --- src/transformers/cache_utils.py | 6 +++-- tests/test_cache_utils.py | 46 ++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 8ac6619bf6..22d0e44b2d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -351,10 +351,12 @@ class StaticCache(Cache): self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.head_dim = config.hidden_size // config.num_attention_heads - self.num_heads = config.num_attention_heads + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype - cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim) + cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.seen_tokens = 0 diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index df6b15f4dc..c6a07bb268 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -35,14 +35,16 @@ if is_torch_available(): AutoModelForCausalLM, AutoTokenizer, DynamicCache, + LlamaConfig, LlamaForCausalLM, SinkCache, + StaticCache, ) @require_torch class CacheTest(unittest.TestCase): - def test_cache_equivalence(self): + def test_dynamic_cache_retrocompatibility(self): """Tests that we can convert back and forth between the legacy cache format and DynamicCache""" legacy_cache = () new_cache = DynamicCache() @@ -120,6 +122,48 @@ class CacheTest(unittest.TestCase): ) ) + def test_static_cache_mha_mqa_gqa(self): + """ + Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query + attention (MQA) + """ + + def _random_kvs(config): + # shape for key and values: (batch_size, num_heads, seq_len, head_dim) + random_keys = torch.rand( + (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads), + device=torch_device, + ) + random_values = torch.rand( + (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads), + device=torch_device, + ) + return random_keys, random_values + + mha_config = LlamaConfig(num_attention_heads=32) + mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) + cached_keys, cached_values = mha_static_cache.update( + *_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)} + ) + self.assertTrue(cached_keys.shape == (1, 32, 10, 128)) + self.assertTrue(cached_values.shape == (1, 32, 10, 128)) + + gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) + gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + cached_keys, cached_values = gqa_static_cache.update( + *_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)} + ) + self.assertTrue(cached_keys.shape == (1, 4, 10, 128)) + self.assertTrue(cached_values.shape == (1, 4, 10, 128)) + + mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) + mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + cached_keys, cached_values = mqa_static_cache.update( + *_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)} + ) + self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) + self.assertTrue(cached_values.shape == (1, 1, 10, 128)) + @require_torch_gpu @slow