Support input_embeds in torch exportable decoders (#39836)

* Support input_embeds in torch exportable decoders

* Hybrid cache update

* Manually change some callsites

* AI changes the rest of the call sites

* Make either input_ids/inputs_embeds mandatory

* Clean up

* Ruff check --fix

* Fix test

* pr review

* Revert config/generation_config changes

* Ruff check
This commit is contained in:
Jack
2025-08-07 01:51:31 -07:00
committed by GitHub
parent cdeaad96b7
commit 6121e9e46c
11 changed files with 325 additions and 85 deletions

View File

@@ -198,34 +198,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def __init__( def __init__(
self, self,
model: PreTrainedModel, model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
): ):
""" """
Initializes the exportable module with `HybridCache`. Initializes the exportable module with `HybridCache`.
Args: Args:
model (`PreTrainedModel`): The pretrained model to wrap. model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises: Raises:
ValueError: If the model is configured with a unsupported cache implementation. ValueError: If the model is configured with a unsupported cache implementation.
""" """
super().__init__() super().__init__()
if not hasattr(model.config, "use_cache") or model.config.use_cache is False: config = model.config.get_text_config()
_generation_config = model.generation_config
if not hasattr(config, "use_cache") or config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.") raise ValueError("The model must have caching enabled to be performant.")
if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None: if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) self.model = TorchExportableModuleWithHybridCache(model)
else: else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null, # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default. # there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info( logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
) )
self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len) self.model = TorchExportableModuleWithStaticCache(model)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -233,24 +232,31 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor] = None,
cache_position: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Forward pass of the module, which is compatible with the ExecuTorch llm runner. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args: Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns: Returns:
torch.Tensor: Logits output from the model. torch.Tensor: Logits output from the model.
""" """
return self.model.forward(input_ids, cache_position) return self.model.forward(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
)
def export( def export(
self, self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None, dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None, strict: Optional[bool] = None,
@@ -260,14 +266,49 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
Args: Args:
input_ids (`Optional[torch.Tensor]`): input_ids (`Optional[torch.Tensor]`):
Tensor representing current input token id to the module. If not provided, a default tensor will be used. Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
inputs_embeds (`Optional[torch.Tensor]`):
Tensor representing current input embeddings to the module. Must specify either this or input_ids.
cache_position (`Optional[torch.Tensor]`): cache_position (`Optional[torch.Tensor]`):
Tensor representing current input position in the cache. If not provided, a default tensor will be used. Tensor representing current input position in the cache. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`): dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified. Dynamic shapes to use for export if specified.
strict(`Optional[bool]`): strict(`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`. Flag to instruct `torch.export` to use `torchdynamo`.
Returns:
torch.export.ExportedProgram: The exported program that can be used for inference.
Examples:
Export with input_ids:
```python
# Prepare inputs
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
# Export
exported = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position
)
```
Export with inputs_embeds:
```python
# Prepare embeddings
inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
# Export
exported = exportable_module.export(
inputs_embeds=inputs_embeds,
cache_position=cache_position
)
```
""" """
if not (input_ids is None) ^ (inputs_embeds is None):
raise ValueError("Need to specify either input_ids or inputs_embeds.")
if hasattr(self.model, "base_model_prefix"): if hasattr(self.model, "base_model_prefix"):
base = getattr(self.model, self.model.base_model_prefix, self.model) base = getattr(self.model, self.model.base_model_prefix, self.model)
model_device = base.device model_device = base.device
@@ -279,20 +320,29 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
) )
example_input_ids = ( if input_ids is not None:
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device) input_kwargs = {
) "input_ids": input_ids,
example_cache_position = ( "cache_position": cache_position
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) if cache_position is not None
) else torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device),
}
else: # inputs_embeds
input_kwargs = {
"inputs_embeds": inputs_embeds,
"cache_position": cache_position
if cache_position is not None
else torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device),
}
exported_program = torch.export.export( exported_program = torch.export.export(
self.model, self.model,
args=(example_input_ids, example_cache_position), args=(),
kwargs={}, kwargs=input_kwargs,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True, strict=strict if strict is not None else True,
) )
return exported_program return exported_program
@staticmethod @staticmethod
@@ -341,7 +391,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass # Forward pass
_ = exported_module(curr_input_ids, curr_cache_position) _ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
curr_position += 1 curr_position += 1
# Generate new tokens # Generate new tokens
@@ -351,7 +401,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass to get next token logits # Forward pass to get next token logits
outputs = exported_module(curr_input_ids, curr_cache_position) outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
# Get the next token ID # Get the next token ID
if do_sample: if do_sample:
@@ -418,8 +468,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
def __init__( def __init__(
self, self,
model: PreTrainedModel, model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
): ):
""" """
Initializes the wrapper module with the pretrained model. Initializes the wrapper module with the pretrained model.
@@ -434,27 +482,31 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
""" """
super().__init__() super().__init__()
# Sanity checks config = model.config.get_text_config()
if model.generation_config is None: generation_config = model.generation_config
# Use default generation config if not specified
model.generation_config = GenerationConfig(
use_cache=model.config.use_cache,
cache_implementation="static",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": "cpu",
},
)
if not model.generation_config.use_cache: # Sanity checks
if generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not generation_config.use_cache:
raise AssertionError( raise AssertionError(
"The model must have caching enabled to be exported with static caching. " "The model must have caching enabled to be exported with static caching. "
"Please set `generation_config.use_cache=True`." "Please set `generation_config.use_cache=True`."
) )
if generation_config.cache_implementation != "static":
if model.generation_config.cache_implementation != "static":
raise AssertionError( raise AssertionError(
"The model must use a 'static' caching implementation to be exported with static caching. " "The model must use a 'static' caching implementation to be exported with static caching. "
"Please set `generation_config.cache_implementation='static'`." "Please set `generation_config.cache_implementation='static'`."
@@ -462,22 +514,29 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
self.model = model self.model = model
self.static_cache = StaticCache( self.static_cache = StaticCache(
config=self.model.config, config=config,
max_batch_size=self.model.generation_config.cache_config.get("batch_size"), max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=self.model.generation_config.cache_config.get("device"), device=generation_config.cache_config.get("device"),
dtype=self.model.dtype, dtype=self.model.dtype,
) )
for i in range(len(self.static_cache)): for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
):
""" """
Forward pass of the module, which is compatible with the ExecuTorch runtime. Forward pass of the module, which is compatible with the ExecuTorch runtime.
Args: Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns: Returns:
@@ -493,15 +552,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`, The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
""" """
_, seqlen = input_ids.shape
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache past_key_values = self.static_cache
outs = self.model( outs = self.model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=None, inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position, cache_position=cache_position,
attention_mask=None,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
@@ -576,33 +633,45 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def __init__( def __init__(
self, self,
model: PreTrainedModel, model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
): ):
""" """
Initializes the exportable module with `HybridCache`. Initializes the exportable module with `HybridCache`.
Args: Args:
model (`PreTrainedModel`): The pretrained model to wrap. model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises: Raises:
AssertionError: If the model doesn't have the expected configuration for HybridCache. AssertionError: If the model doesn't have the expected configuration for HybridCache.
""" """
super().__init__() super().__init__()
self.model = model self.model = model
config = model.config.get_text_config()
generation_config = model.generation_config
# Verify the model is configured for HybridCache if generation_config is None:
if not self.model.config.use_cache: raise AssertionError(
raise AssertionError("Model must have caching enabled") "The model must have a generation config to be exported with static caching. "
"Please set `generation_config` in `model`."
)
if "batch_size" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a batch_size in its cache_config. "
'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
)
if "max_cache_len" not in generation_config.cache_config:
raise ValueError(
"The model's generation config must specify a max_cache_len in its cache_config. "
'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
)
if not config.use_cache:
raise AssertionError("Model must have caching enabled.")
# Initialize the HybridCache # Initialize the HybridCache
self.cache = HybridCache( self.cache = HybridCache(
config=self.model.config, config=config,
max_batch_size=max_batch_size, max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=max_cache_len, max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=self.model.device, device=generation_config.cache_config.get("device"),
dtype=self.model.dtype, dtype=self.model.dtype,
) )
@@ -613,32 +682,29 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.LongTensor] = None,
cache_position: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Forward pass of the module, which is compatible with the ExecuTorch llm runner. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args: Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns: Returns:
torch.Tensor: Logits output from the model. torch.Tensor: Logits output from the model.
""" """
batch_size = input_ids.shape[0]
# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
# Forward pass with the model # Forward pass with the model
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
attention_mask=None, attention_mask=None,
position_ids=position_ids,
past_key_values=self.cache, past_key_values=self.cache,
use_cache=True, use_cache=True,
cache_position=cache_position,
) )
# Return only the logits to simplify the export # Return only the logits to simplify the export
@@ -692,8 +758,8 @@ def convert_and_export_with_cache(
if is_torch_greater_or_equal("2.6.0"): if is_torch_greater_or_equal("2.6.0"):
exported_program = torch.export.export( exported_program = torch.export.export(
TorchExportableModuleWithStaticCache(model), TorchExportableModuleWithStaticCache(model),
args=(example_input_ids, example_cache_position), args=(),
kwargs={}, kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True, strict=strict if strict is not None else True,
) )
@@ -710,8 +776,8 @@ def convert_and_export_with_cache(
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
exported_program = torch.export._trace._export( exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model), TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,), args=(),
kwargs={"cache_position": example_cache_position}, kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
pre_dispatch=False, pre_dispatch=False,
strict=True, strict=True,
) )

View File

@@ -460,7 +460,10 @@ class GemmaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@@ -365,7 +365,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )
@@ -389,7 +392,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
# Export + HybridCache # Export + HybridCache
model.eval() model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
# Test generation with the exported model # Test generation with the exported model
prompt = "What is the capital of France?" prompt = "What is the capital of France?"

View File

@@ -822,7 +822,10 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Export + HybridCache # Export + HybridCache
model.eval() model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
logging.info(f"\nExported program: {exported_program}") logging.info(f"\nExported program: {exported_program}")
# Test generation with the exported model # Test generation with the exported model

View File

@@ -353,7 +353,10 @@ class LlamaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@@ -384,7 +384,10 @@ class OlmoIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@@ -417,7 +417,10 @@ class Phi3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export() exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@@ -303,7 +303,11 @@ class Qwen2IntegrationTest(unittest.TestCase):
strict = version.parse(torch.__version__) != version.parse( strict = version.parse(torch.__version__) != version.parse(
"2.7.0" "2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994 ) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(strict=strict) exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@@ -293,7 +293,11 @@ class Qwen3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(strict=strict) exported_program = exportable_module.export(
input_ids=prompt_token_ids,
cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
strict=strict,
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

129
tests/test_executorch.py Normal file
View File

@@ -0,0 +1,129 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoModelForCausalLM, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
TorchExportableModuleWithHybridCache,
TorchExportableModuleWithStaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.testing_utils import require_torch
@require_torch
class ExecutorchTest(unittest.TestCase):
def setUp(self):
if not is_torch_greater_or_equal_than_2_3:
self.skipTest("torch >= 2.3 is required")
set_seed(0)
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
self.model.eval()
# Create generation config with static cache for the model
self.model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
self.cache_position = torch.arange(3, dtype=torch.long)
def test_static_cache_module_forward(self):
"""Test TorchExportableModuleWithStaticCache forward with both input types"""
generation_config = GenerationConfig(
use_cache=True,
cache_implementation="static",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithStaticCache(self.model)
# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
def test_hybrid_cache_module_forward(self):
"""Test TorchExportableModuleWithHybridCache forward with both input types"""
config = self.model.config
config.sliding_window = 16
config.layer_types = ["full_attention"] * config.num_hidden_layers
generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)
# Set generation config on model
self.model.generation_config = generation_config
module = TorchExportableModuleWithHybridCache(self.model)
# Test with input_ids
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
# Test with inputs_embeds
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
def test_decoder_only_lm_export_validation(self):
"""Test TorchExportableModuleForDecoderOnlyLM export validation"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)
# Should fail with both input_ids and inputs_embeds
with self.assertRaises(ValueError):
module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)
# Should fail with neither
with self.assertRaises(ValueError):
module.export()
def test_decoder_only_lm_export(self):
"""Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
module = TorchExportableModuleForDecoderOnlyLM(self.model)
# Test export with input_ids
exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
exported_output_ids = exported_program_ids.module()(
input_ids=self.input_ids, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)
# Test export with inputs_embeds
exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
exported_output_embeds = exported_program_embeds.module()(
inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
)
torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)

View File

@@ -841,8 +841,24 @@ class CacheExportIntegrationTest(unittest.TestCase):
model.eval() model.eval()
max_batch_size = 1 max_batch_size = 1
max_cache_len = 23 max_cache_len = 23
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len) # Set generation config on the model for the hybrid cache model
exported_program = exportable_module.export() from transformers.generation.configuration_utils import GenerationConfig
model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation="hybrid",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": model.device,
},
)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
n_g_key_caches = n_g_value_caches = 0 n_g_key_caches = n_g_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers(): for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("key_cache"): if buffer_name.startswith("key_cache"):