diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 10a8689388..04ccc6f7ef 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -695,7 +695,7 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
         "key_cache": getattr(cache, "key_cache"),
         "value_cache": getattr(cache, "value_cache"),
     }
-    return torch.utils._pytree.tree_flatten(dictionary)[0]
+    return torch.fx._pytree._dict_flatten_spec(dictionary, spec)
 
 
 if is_torch_greater_or_equal("2.3"):
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index b1f41153e2..8c864f9b64 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -626,6 +626,114 @@ class CacheExportIntegrationTest(unittest.TestCase):
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
             self.assertTrue(torch.allclose(v1, v2))
 
+    def test_dynamic_cache_exportability_multiple_run(self):
+        """
+        Tests that a graph exported with a populated `DynamicCache` can be called several times.
+
+        When exporting with DynamicCache, you should export two graphs:
+            1. A graph without cache
+            2. A graph with cache
+        In the future, we will make improvements to export API to export two graphs
+        more seamlessly.
+        """
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model = model.eval()
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        prompt = "What is the best way to debug python script?"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        attention_mask = inputs.attention_mask
+        input_ids = inputs.input_ids
+
+        # First graph: exported and run with an empty cache.
+        ep = export_with_dynamic_cache(model, input_ids, attention_mask)
+        res = ep.module()(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=DynamicCache(),
+            use_cache=True,
+        )
+        # assertEqual reports both values on failure, unlike assertTrue(a == b).
+        self.assertEqual(model.config.num_hidden_layers, len(res.past_key_values.key_cache))
+        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
+        self.assertEqual(
+            3,
+            len(
+                [
+                    x
+                    for x in ep.graph_signature.input_specs
+                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
+                ]
+            ),
+        )
+
+        res_eager = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=DynamicCache(),
+            use_cache=True,
+        )
+        past_key_values_eager = res_eager.past_key_values
+        past_key_values = res.past_key_values
+
+        # Second graph: exported with the populated cache, marking dim 2 of every cached tensor as dynamic.
+        shapes = torch.export.ShapesCollection()
+        dyn = torch.export.Dim("seq", max=512)
+
+        for key, value in zip(past_key_values.key_cache, past_key_values.value_cache):
+            shapes[key] = (None, None, dyn, None)
+            shapes[value] = (None, None, dyn, None)
+
+        ep_second = torch.export.export(
+            model,
+            (),
+            {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": True,
+            },
+            strict=False,
+            dynamic_shapes=shapes,
+        )
+        res_export = ep_second.module()(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        # It should work with variable len
+        res_export_2 = ep_second.module()(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=res_export.past_key_values,
+            use_cache=True,
+        )
+
+        res_eager = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values_eager,
+            use_cache=True,
+        )
+        res_eager_2 = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=res_eager.past_key_values,
+            use_cache=True,
+        )
+
+        exported_cache = res_export_2.past_key_values
+        eager_cache = res_eager_2.past_key_values
+        # zip() stops at the shorter input, so check the layer counts first.
+        self.assertEqual(len(eager_cache.key_cache), len(exported_cache.key_cache))
+        self.assertEqual(len(eager_cache.value_cache), len(exported_cache.value_cache))
+
+        for k1, k2 in zip(exported_cache.key_cache, eager_cache.key_cache):
+            self.assertTrue(torch.allclose(k1, k2))
+
+        for v1, v2 in zip(exported_cache.value_cache, eager_cache.value_cache):
+            self.assertTrue(torch.allclose(v1, v2))
+
     def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`