Support tracable dynamicKVcache (#36311)
* Support tracable dynamicKVcache * Fix lint * More fine grained test * Lint * Update * Update * Fix up * Apply suggestions from code review * Update src/transformers/cache_utils.py * Update tests/utils/test_cache_utils.py * Apply suggestions from code review * Update * Change error message * Rename * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
63c3116530
commit
f39f4960f3
@@ -174,6 +174,60 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
def test_dynamic_cache_exportability(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
model = model.eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
prompt = "What is the best way to debug python script?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
attention_mask = inputs.attention_mask
|
||||
input_ids = inputs.input_ids
|
||||
|
||||
past_key_values = DynamicCache()
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
(),
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": True,
|
||||
},
|
||||
strict=False,
|
||||
)
|
||||
res = ep.module()(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
len(
|
||||
[
|
||||
x
|
||||
for x in ep.graph_signature.input_specs
|
||||
if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
past_key_values_eager = DynamicCache()
|
||||
res_eager = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values_eager,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
|
||||
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2))
|
||||
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_static_cache_exportability(self):
|
||||
|
||||
Reference in New Issue
Block a user