Support tracable dynamicKVcache (#36311)

* Support tracable dynamicKVcache

* Fix lint

* More fine grained test

* Lint

* Update

* Update

* Fix up

* Apply suggestions from code review

* Update src/transformers/cache_utils.py

* Update tests/utils/test_cache_utils.py

* Apply suggestions from code review

* Update

* Change error message

* Rename

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-03-19 12:52:30 -04:00
committed by GitHub
parent 63c3116530
commit f39f4960f3
3 changed files with 116 additions and 4 deletions

View File

@@ -174,6 +174,60 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
def test_dynamic_cache_exportability(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
prompt = "What is the best way to debug python script?"
inputs = tokenizer(prompt, return_tensors="pt")
attention_mask = inputs.attention_mask
input_ids = inputs.input_ids
past_key_values = DynamicCache()
ep = torch.export.export(
model,
(),
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
},
strict=False,
)
res = ep.module()(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
len(
[
x
for x in ep.graph_signature.input_specs
if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
]
),
)
past_key_values_eager = DynamicCache()
res_eager = model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values_eager,
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
@slow
@require_read_token
def test_static_cache_exportability(self):