Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders
* Hybrid cache update
* Manually change some callsites
* AI changes the rest of the call sites
* Make either input_ids/inputs_embeds mandatory
* Clean up
* Ruff check --fix
* Fix test
* pr review
* Revert config/generation_config changes
* Ruff check
This commit is contained in:
@@ -365,7 +365,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
@@ -389,7 +392,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
# Export + HybridCache
|
||||
model.eval()
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
|
||||
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
|
||||
)
|
||||
|
||||
# Test generation with the exported model
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
Reference in New Issue
Block a user