Fix bugs in DynamicCache (#37880)

* Fix bugs in DynamicCache

* Update

* Update

* Lint

* lint

* Rename test

* update

* update
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-06-24 13:43:40 -04:00
committed by GitHub
parent 6bdd4ec952
commit 67d36dc1d7
2 changed files with 97 additions and 1 deletions

View File

@@ -626,6 +626,102 @@ class CacheExportIntegrationTest(unittest.TestCase):
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
def test_dynamic_cache_exportability_multiple_run(self):
    """
    Tests that a graph exported with a populated `DynamicCache` can be called several
    times in a row and produces the same cache contents as eager execution.
    """
    # When exporting with DynamicCache, you should export two graphs:
    #   1. A graph without cache
    #   2. A graph with cache
    # In the future, we will make improvements to export API to export two graphs
    # more seamlessly.
    checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    prompt = "What is the best way to debug python script?"
    inputs = tokenizer(prompt, return_tensors="pt")
    attention_mask = inputs.attention_mask
    input_ids = inputs.input_ids

    # Graph 1: exported and run starting from an empty cache.
    ep = export_with_dynamic_cache(model, input_ids, attention_mask)
    res = ep.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=DynamicCache(),
        use_cache=True,
    )

    num_layers = model.config.num_hidden_layers
    self.assertEqual(len(res.past_key_values.key_cache), num_layers)
    # Two cache outputs per layer (key + value) plus one additional graph output.
    self.assertEqual(2 * num_layers + 1, len(ep.graph_signature.output_specs))
    user_input_specs = [
        spec
        for spec in ep.graph_signature.input_specs
        if spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
    ]
    self.assertEqual(3, len(user_input_specs))

    # Eager reference for the first step; its cache object is fed back into the model below.
    res_eager = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=DynamicCache(),
        use_cache=True,
    )
    past_key_values_eager = res_eager.past_key_values
    past_key_values = res.past_key_values

    # Graph 2: exported with the populated cache. Dimension 2 of every cached key/value
    # tensor is tied to the dynamic dim "seq" so the graph accepts caches of different lengths.
    shapes = torch.export.ShapesCollection()
    dyn = torch.export.Dim("seq", max=512)
    for key_states, value_states in zip(past_key_values.key_cache, past_key_values.value_cache):
        shapes[key_states] = (None, None, dyn, None)
        shapes[value_states] = (None, None, dyn, None)

    ep_second = torch.export.export(
        model,
        (),
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
        },
        strict=False,
        dynamic_shapes=shapes,
    )
    res_export = ep_second.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    # It should work with variable len
    res_export_2 = ep_second.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=res_export.past_key_values,
        use_cache=True,
    )

    # Mirror the two cached steps in eager mode.
    res_eager = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values_eager,
        use_cache=True,
    )
    res_eager_2 = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=res_eager.past_key_values,
        use_cache=True,
    )

    exported_cache = res_export_2.past_key_values
    eager_cache = res_eager_2.past_key_values
    # zip() stops at the shorter sequence, so check the layer counts explicitly; otherwise
    # a cache with missing layers would pass the comparisons below vacuously.
    self.assertEqual(len(exported_cache.key_cache), len(eager_cache.key_cache))
    self.assertEqual(len(exported_cache.value_cache), len(eager_cache.value_cache))

    for layer_idx, (k1, k2) in enumerate(zip(exported_cache.key_cache, eager_cache.key_cache)):
        self.assertTrue(torch.allclose(k1, k2), msg=f"key cache mismatch at layer {layer_idx}")

    for layer_idx, (v1, v2) in enumerate(zip(exported_cache.value_cache, eager_cache.value_cache)):
        self.assertTrue(torch.allclose(v1, v2), msg=f"value cache mismatch at layer {layer_idx}")
def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`