Fix bugs in DynamicCache (#37880)
* Fix bugs in DynamicCache * Updarte * Update * Lint * lint * Rename test * update * update
This commit is contained in:
committed by
GitHub
parent
6bdd4ec952
commit
67d36dc1d7
@@ -626,6 +626,102 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
|
||||
def test_dynamic_cache_exportability_multiple_run(self):
|
||||
# When exporting with DynamicCache, you should export two graphs:
|
||||
# 1. A graph without cache
|
||||
# 2. A graph with cache
|
||||
# In the future, we will make improvements to export API to export two graphs
|
||||
# more seamlessly.
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
model = model.eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
prompt = "What is the best way to debug python script?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
attention_mask = inputs.attention_mask
|
||||
input_ids = inputs.input_ids
|
||||
|
||||
ep = export_with_dynamic_cache(model, input_ids, attention_mask)
|
||||
res = ep.module()(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
len(
|
||||
[
|
||||
x
|
||||
for x in ep.graph_signature.input_specs
|
||||
if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
res_eager = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values_eager = res_eager.past_key_values
|
||||
past_key_values = res.past_key_values
|
||||
|
||||
shapes = torch.export.ShapesCollection()
|
||||
dyn = torch.export.Dim("seq", max=512)
|
||||
|
||||
for ix in range(len(past_key_values.key_cache)):
|
||||
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
|
||||
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
|
||||
|
||||
ep_second = torch.export.export(
|
||||
model,
|
||||
(),
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": True,
|
||||
},
|
||||
strict=False,
|
||||
dynamic_shapes=shapes,
|
||||
)
|
||||
res_export = ep_second.module()(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
# It should work with variable len
|
||||
res_export_2 = ep_second.module()(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=res_export.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
res_eager = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values_eager,
|
||||
use_cache=True,
|
||||
)
|
||||
res_eager_2 = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=res_eager.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2))
|
||||
|
||||
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
|
||||
def test_static_cache_exportability(self):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
|
||||
Reference in New Issue
Block a user