Allow override inputs to export recipe (#37508)
Add option to specify dynamic shapes during export Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -25,7 +25,6 @@ from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
require_non_xpu,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
@@ -693,8 +692,6 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_static_cache_exportability(self):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
@@ -709,8 +706,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||
batch_size = 1
|
||||
max_cache_len = 1234
|
||||
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
model_id,
|
||||
device_map=device,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=attn_implementation,
|
||||
@@ -748,3 +746,59 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
n_static_value_caches = n_static_value_caches + 1
|
||||
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
|
||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
# Export with dynamic shapes using Dim.AUTO
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
|
||||
exported_program = convert_and_export_with_cache(
|
||||
model,
|
||||
example_input_ids=input_ids,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def test_hybrid_cache_exportability(self):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
"""
|
||||
if not is_torch_greater_or_equal("2.6"):
|
||||
self.skipTest(reason="This test requires torch >= 2.6 to run.")
|
||||
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
set_seed(0)
|
||||
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
model.eval()
|
||||
self.assertEqual(model.config.use_cache, True)
|
||||
self.assertEqual(model.config.cache_implementation, "hybrid")
|
||||
|
||||
# Export + HybridCache
|
||||
model.eval()
|
||||
max_batch_size = 1
|
||||
max_cache_len = 23
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
|
||||
exported_program = exportable_module.export()
|
||||
n_g_key_caches = n_g_value_caches = 0
|
||||
for buffer_name, buffer in exported_program.named_buffers():
|
||||
if buffer_name.startswith("key_cache"):
|
||||
self.assertTrue(buffer.shape[0] == max_batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_g_key_caches = n_g_key_caches + 1
|
||||
if buffer_name.startswith("value_cache"):
|
||||
self.assertTrue(buffer.shape[0] == max_batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_g_value_caches = n_g_value_caches + 1
|
||||
self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
|
||||
self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
# Export with dynamic shapes using Dim.AUTO
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=input_ids,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user