Generate: end-to-end compilation (#30788)

* mvp

* added test (a few models need fixes)

* fix a few test cases

* test nits

* harder test 😈

* revert changes in stablelm

* test with improved condition

* add todo

* tmp commit

* merged with main

* nits

* add todo

* final corrections

* add docs for generation compilation

* docs nits

* add  tip

* PR suggestions

* add more details to the compilation docs

* fix cache positions

* cache is now init in generate; update docs

* tag test as flaky

* docs

* post rebase make fixup and other nits

* remove unintended changes

* whisper (encoder-decoder) not supported

* move token default updates to ; add tests for token defaults

* push changes

* manual rebase

* chameleon doesn't support this

* fix test_static_cache_mha_mqa_gqa (broken in another PR)

* docs: dynamic is better with end-to-end compilation
This commit is contained in:
Joao Gante
2024-07-29 10:52:13 +01:00
committed by GitHub
parent 49928892d6
commit 7ffe25f2b9
11 changed files with 285 additions and 103 deletions

View File

@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))