Support inputs_embeds in torch exportable decoders (#39836)

* Support inputs_embeds in torch exportable decoders

* Hybrid cache update

* Manually change some call sites

* AI changes the rest of the call sites

* Make either input_ids/inputs_embeds mandatory

* Clean up

* Ruff check --fix

* Fix test

* Address PR review

* Revert config/generation_config changes

* Ruff check
This commit is contained in:
Jack
2025-08-07 01:51:31 -07:00
committed by GitHub
parent cdeaad96b7
commit 6121e9e46c
11 changed files with 325 additions and 85 deletions

View File

@@ -822,7 +822,10 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-exported_program = exportable_module.export()
+exported_program = exportable_module.export(
+input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+)
logging.info(f"\nExported program: {exported_program}")
# Test generation with the exported model