Make cache traceable (#35873)

Simply make the cache traceable.
This commit is contained in:
Ilyas Moutawwakil
2025-02-20 09:59:25 +01:00
committed by GitHub
parent 31bb662db1
commit 5e2183f344
3 changed files with 21 additions and 30 deletions

View File

@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
# Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
-if buffer_name.startswith("static_cache.key_cache"):
+if buffer_name.startswith("key_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1
-if buffer_name.startswith("static_cache.value_cache"):
+if buffer_name.startswith("value_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
-self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+self.assertEqual(responses, EXPECTED_DECODED_TEXT)