Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders * Hybrid cache update * Manually change some callsites * AI changes the rest of the call sites * Make either input_ids/inputs_embeds mandatory * Clean up * Ruff check --fix * Fix test * pr review * Revert config/generation_config changes * Ruff check
This commit is contained in:
@@ -460,7 +460,10 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -365,7 +365,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
@@ -389,7 +392,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
# Export + HybridCache
|
||||
model.eval()
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
|
||||
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
|
||||
)
|
||||
|
||||
# Test generation with the exported model
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
@@ -822,7 +822,10 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
# Export + HybridCache
|
||||
model.eval()
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
|
||||
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
|
||||
)
|
||||
logging.info(f"\nExported program: {exported_program}")
|
||||
|
||||
# Test generation with the exported model
|
||||
|
||||
@@ -353,7 +353,10 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -384,7 +384,10 @@ class OlmoIntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -417,7 +417,10 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -303,7 +303,11 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
strict = version.parse(torch.__version__) != version.parse(
|
||||
"2.7.0"
|
||||
) # Due to https://github.com/pytorch/pytorch/issues/150994
|
||||
exported_program = exportable_module.export(strict=strict)
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
strict=strict,
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
@@ -293,7 +293,11 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export(strict=strict)
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=prompt_token_ids,
|
||||
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
|
||||
strict=strict,
|
||||
)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user