Fix static generation when compiling! (#28937)
* wow I was scared! * fix everything * nits * make it BC? * add todo * nits * is_tracing should still be used to pass tracing tests * nits * some nits to make sure generation works with static cache uncompiled * fix sdpa * fix FA2 for both static and dynamic in a better way? * style * fix-copies * fix fix copies * fix sequential beam search * style * use `keys_to_ignore` * nit * correct dtype inference when init * :( the fix for FA2 is still not optimal to investigate! * styling * nits * nit * this might work better * add comment * Update src/transformers/models/llama/modeling_llama.py * "position_ids" -> "cache_position" * style * nit * Remove changes that should not be propagated just yet * Apply suggestions from code review * Styling * make sure we raise an error for static cache with FA2 enabled * move to the bottom of the signature * style * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/llama/modeling_llama.py * nit in the name --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
|
||||
mha_config = LlamaConfig(num_attention_heads=32)
|
||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mha_static_cache.update(
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||
@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
|
||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = gqa_static_cache.update(
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
|
||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mqa_static_cache.update(
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
Reference in New Issue
Block a user