Support input_embeds in torch exportable decoders (#39836)
* Support input_embeds in torch exportable decoders * Hybrid cache update * Manually change some callsites * AI changes the rest of the call sites * Make either input_ids/inputs_embeds mandatory * Clean up * Ruff check --fix * Fix test * pr review * Revert config/generation_config changes * Ruff check
This commit is contained in:
@@ -841,8 +841,24 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
model.eval()
|
||||
max_batch_size = 1
|
||||
max_cache_len = 23
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
|
||||
exported_program = exportable_module.export()
|
||||
# Set generation config on the model for the hybrid cache model
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
|
||||
model.generation_config = GenerationConfig(
|
||||
use_cache=True,
|
||||
cache_implementation="hybrid",
|
||||
max_length=max_cache_len,
|
||||
cache_config={
|
||||
"batch_size": max_batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
"device": model.device,
|
||||
},
|
||||
)
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export(
|
||||
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
|
||||
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
|
||||
)
|
||||
n_g_key_caches = n_g_value_caches = 0
|
||||
for buffer_name, buffer in exported_program.named_buffers():
|
||||
if buffer_name.startswith("key_cache"):
|
||||
|
||||
Reference in New Issue
Block a user