Support inputs_embeds in torch exportable decoders (#39836)

* Support inputs_embeds in torch exportable decoders

* Hybrid cache update

* Manually change some callsites

* AI changes the rest of the call sites

* Require either input_ids or inputs_embeds

* Clean up

* Ruff check --fix

* Fix test

* PR review

* Revert config/generation_config changes

* Ruff check
This commit is contained in:
Jack
2025-08-07 01:51:31 -07:00
committed by GitHub
parent cdeaad96b7
commit 6121e9e46c
11 changed files with 325 additions and 85 deletions

View File

@@ -841,8 +841,24 @@ class CacheExportIntegrationTest(unittest.TestCase):
model.eval()
max_batch_size = 1
max_cache_len = 23
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
exported_program = exportable_module.export()
# Set generation config on the model for the hybrid cache model
from transformers.generation.configuration_utils import GenerationConfig
model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": model.device,
},
)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
n_g_key_caches = n_g_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("key_cache"):