Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders * Hybrid cache update * Manually change some callsites * AI changes the rest of the call sites * Make either input_ids/inputs_embeds mandatory * Clean up * Ruff check --fix * Fix test * pr review * Revert config/generation_config changes * Ruff check
This commit is contained in:
@@ -460,7 +460,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user