Fix bugs in DynamicCache (#37880)
* Fix bugs in DynamicCache * Update * Update * Lint * lint * Rename test * update * update
This commit is contained in:
committed by
GitHub
parent
6bdd4ec952
commit
67d36dc1d7
@@ -695,7 +695,7 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
|
|||||||
"key_cache": getattr(cache, "key_cache"),
|
"key_cache": getattr(cache, "key_cache"),
|
||||||
"value_cache": getattr(cache, "value_cache"),
|
"value_cache": getattr(cache, "value_cache"),
|
||||||
}
|
}
|
||||||
return torch.utils._pytree.tree_flatten(dictionary)[0]
|
return torch.fx._pytree._dict_flatten_spec(dictionary, spec)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_greater_or_equal("2.3"):
|
if is_torch_greater_or_equal("2.3"):
|
||||||
|
|||||||
@@ -626,6 +626,102 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
|||||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||||
self.assertTrue(torch.allclose(v1, v2))
|
self.assertTrue(torch.allclose(v1, v2))
|
||||||
|
|
||||||
|
def test_dynamic_cache_exportability_multiple_run(self):
    """
    Tests that a model exported with a populated `DynamicCache` can be called several times.

    The model is first exported and run with an empty cache, then exported a second time with the
    cache produced by that run and a dynamic sequence dimension. The second export is called twice
    in a row, feeding the cache returned by one call into the next, and the resulting key/value
    tensors are compared against the same sequence of calls made on the eager model.
    """
    # When exporting with DynamicCache, you should export two graphs:
    # 1. A graph without cache
    # 2. A graph with cache
    # In the future, we will make improvements to export API to export two graphs
    # more seamlessly.

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    prompt = "What is the best way to debug python script?"
    inputs = tokenizer(prompt, return_tensors="pt")
    attention_mask = inputs.attention_mask
    input_ids = inputs.input_ids

    # First graph: exported and run with an empty cache.
    ep = export_with_dynamic_cache(model, input_ids, attention_mask)
    res = ep.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=DynamicCache(),
        use_cache=True,
    )
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(len(res.past_key_values.key_cache), model.config.num_hidden_layers)
    self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
    user_inputs = [
        spec
        for spec in ep.graph_signature.input_specs
        if spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
    ]
    self.assertEqual(3, len(user_inputs))

    # Eager reference for the first step.
    res_eager = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=DynamicCache(),
        use_cache=True,
    )
    past_key_values_eager = res_eager.past_key_values
    past_key_values = res.past_key_values

    # Mark the third dimension of every cached key/value tensor as dynamic for the second export.
    shapes = torch.export.ShapesCollection()
    dyn = torch.export.Dim("seq", max=512)

    for key_states, value_states in zip(past_key_values.key_cache, past_key_values.value_cache):
        shapes[key_states] = (None, None, dyn, None)
        shapes[value_states] = (None, None, dyn, None)

    # Second graph: exported with the cache produced by the first run.
    ep_second = torch.export.export(
        model,
        (),
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
        },
        strict=False,
        dynamic_shapes=shapes,
    )
    res_export = ep_second.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    # It should work with variable len
    res_export_2 = ep_second.module()(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=res_export.past_key_values,
        use_cache=True,
    )

    # Same two follow-up steps on the eager model.
    res_eager = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values_eager,
        use_cache=True,
    )
    res_eager_2 = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=res_eager.past_key_values,
        use_cache=True,
    )

    exported_cache = res_export_2.past_key_values
    eager_cache = res_eager_2.past_key_values

    # zip() stops at the shorter input, so without these checks the loops below would pass
    # vacuously if one of the caches came back empty or with fewer layers.
    self.assertEqual(len(exported_cache.key_cache), len(eager_cache.key_cache))
    self.assertEqual(len(exported_cache.value_cache), len(eager_cache.value_cache))

    for k1, k2 in zip(exported_cache.key_cache, eager_cache.key_cache):
        self.assertTrue(torch.allclose(k1, k2))

    for v1, v2 in zip(exported_cache.value_cache, eager_cache.value_cache):
        self.assertTrue(torch.allclose(v1, v2))
|
||||||
|
|
||||||
def test_static_cache_exportability(self):
|
def test_static_cache_exportability(self):
|
||||||
"""
|
"""
|
||||||
Tests that static cache works with `torch.export()`
|
Tests that static cache works with `torch.export()`
|
||||||
|
|||||||
Reference in New Issue
Block a user