Allow override inputs to export recipe (#37508)

Add option to specify dynamic shapes during export

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
Guang Yang
2025-04-30 01:19:27 -07:00
committed by GitHub
parent 481de7204c
commit a57274466f
2 changed files with 90 additions and 16 deletions

View File

@@ -25,7 +25,6 @@ from transformers.testing_utils import (
is_torch_available,
require_gptq,
require_non_xpu,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
@@ -693,8 +692,6 @@ class CacheExportIntegrationTest(unittest.TestCase):
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
@slow
@require_read_token
def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
@@ -709,8 +706,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
batch_size = 1
max_cache_len = 1234
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
model_id,
device_map=device,
torch_dtype=dtype,
attn_implementation=attn_implementation,
@@ -748,3 +746,59 @@ class CacheExportIntegrationTest(unittest.TestCase):
n_static_value_caches = n_static_value_caches + 1
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes using Dim.AUTO
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
exported_program = convert_and_export_with_cache(
model,
example_input_ids=input_ids,
dynamic_shapes=dynamic_shapes,
strict=False,
)
def test_hybrid_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
"""
if not is_torch_greater_or_equal("2.6"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
set_seed(0)
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()
self.assertEqual(model.config.use_cache, True)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
model.eval()
max_batch_size = 1
max_cache_len = 23
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
exported_program = exportable_module.export()
n_g_key_caches = n_g_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("key_cache"):
self.assertTrue(buffer.shape[0] == max_batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_g_key_caches = n_g_key_caches + 1
if buffer_name.startswith("value_cache"):
self.assertTrue(buffer.shape[0] == max_batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_g_value_caches = n_g_value_caches + 1
self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes using Dim.AUTO
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
exported_program = exportable_module.export(
input_ids=input_ids,
dynamic_shapes=dynamic_shapes,
strict=False,
)