Export SmolvLM (#39614)

Export SmolVLM for ExecuTorch
This commit is contained in:
Guang Yang
2025-08-05 07:20:23 -07:00
committed by GitHub
parent c430047602
commit d2ae766836
5 changed files with 282 additions and 18 deletions

View File

@@ -27,6 +27,167 @@ from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
# Add this to src/transformers/integrations/executorch.py
class TorchExportableModuleForVLM:
"""
A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
This class handles the export of three main components:
1. Vision encoder (processes images to visual features)
2. Connector/projector (maps visual features to text embedding space)
3. Text decoder (generates text from combined visual and text tokens)
"""
def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
"""
Initialize the exportable VLM module.
Args:
model: The VLM (e.g. SmolVLM) model instance
max_batch_size: Maximum batch size. Always 1 for ExecuTorch
max_cache_len: Maximum cache length for text generation
"""
self.model = model
self.max_batch_size = max_batch_size
self.max_cache_len = max_cache_len
self.config = model.config
# Extract individual components
self.vision_encoder = model.model.vision_model
self.connector = model.model.connector
self.text_decoder = model.model.text_model
# Store exported programs
self.exported_vision_encoder = None
self.exported_connector = None
self.exported_text_decoder = None
def export_vision_encoder(self):
"""Export the vision encoder component."""
self.vision_encoder.eval()
# Create example input
pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
# Define dynamic shapes
dynamic_shapes = {
"pixel_values": {
2: torch.export.Dim.AUTO,
3: torch.export.Dim.AUTO,
}
}
self.exported_vision_encoder = torch.export.export(
self.vision_encoder,
args=(pixel_values,),
dynamic_shapes=dynamic_shapes,
strict=False,
)
return self.exported_vision_encoder
def export_connector(self):
"""Export the connector component."""
self.connector.eval()
# Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
vision_hidden_size = self.config.vision_config.hidden_size
image_size = self.config.vision_config.image_size
patch_size = self.config.vision_config.patch_size
patches_per_dim = image_size // patch_size
num_patches = patches_per_dim * patches_per_dim
image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
# Define dynamic shapes - static batch_size=1, dynamic num_patches
dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
# Export the connector using torch.export
self.exported_connector = torch.export.export(
self.connector,
args=(image_hidden_states,),
dynamic_shapes=dynamic_shapes,
strict=False,
)
return self.exported_connector
def export_text_decoder(self):
"""Export the text decoder component."""
# Create text decoder exportable wrapper
self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(
model=self.text_decoder,
max_batch_size=self.max_batch_size,
max_cache_len=self.max_cache_len,
)
# Use the existing text decoder exportable wrapper
seq_length = 3
input_ids = torch.zeros((1, seq_length), dtype=torch.long)
cache_position = torch.arange(seq_length, dtype=torch.long)
max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
dynamic_shapes = {
"input_ids": {1: seq_len_dim},
"cache_position": {0: seq_len_dim},
}
self.exported_text_decoder = self.exportable_text_decoder.export(
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=False,
)
return self.exported_text_decoder
def export(self, **kwargs):
"""Export all components of the VLM model."""
self.export_vision_encoder(**kwargs)
self.export_connector(**kwargs)
self.export_text_decoder(**kwargs)
return {
"vision_encoder": self.exported_vision_encoder,
"connector": self.exported_connector,
"text_decoder": self.exported_text_decoder,
}
def forward(self, pixel_values, input_ids, cache_position):
"""
Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
Args:
pixel_values: Input images [1, channels, height, width] (optional)
input_ids: Text token IDs [1, seq_len] (required - won't be None)
cache_position: Cache positions [seq_len] (required - won't be None)
Returns:
Output with logits for text generation
"""
pass
def generate(
self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
):
"""
Simplified generate method with guaranteed non-null input_ids.
Args:
pixel_values: Input images [1, channels, height, width] (optional)
input_ids: Initial text tokens [1, seq_len] (required - won't be None)
max_new_tokens: Maximum number of tokens to generate
do_sample: Whether to use sampling or greedy decoding
temperature: Temperature for sampling
Returns:
Generated sequences
"""
pass
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
@@ -64,7 +225,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)
self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -254,7 +415,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
"""
def __init__(self, model: PreTrainedModel):
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the wrapper module with the pretrained model.
@@ -270,9 +436,16 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
# Sanity checks
if model.generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config`."
# Use default generation config if not specified
model.generation_config = GenerationConfig(
use_cache=model.config.use_cache,
cache_implementation="static",
max_length=max_cache_len,
cache_config={
"batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": "cpu",
},
)
if not model.generation_config.use_cache:
@@ -332,7 +505,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
past_key_values=past_key_values,
use_cache=True,
)
if hasattr(outs, "logits"):
# Returned outputs is `CausalLMOutputWithPast`
return outs.logits
else:
# Returned the `last_hidden_state` from `BaseModelOutputWithPast`
return outs.last_hidden_state
@staticmethod
def generate(

View File

@@ -147,8 +147,11 @@ class Idefics2VisionEmbeddings(nn.Module):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

View File

@@ -147,8 +147,11 @@ class Idefics3VisionEmbeddings(nn.Module):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
@@ -558,10 +561,10 @@ class Idefics3VisionTransformer(Idefics3PreTrainedModel):
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
elif not self._use_flash_attention_2:
if not self._use_flash_attention_2:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
elif not torch.any(~patch_attention_mask):
patch_attention_mask = None
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,

View File

@@ -142,8 +142,11 @@ class SmolVLMVisionEmbeddings(nn.Module):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
@@ -445,10 +448,10 @@ class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
elif not self._use_flash_attention_2:
if not self._use_flash_attention_2:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
elif not torch.any(~patch_attention_mask):
patch_attention_mask = None
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,

View File

@@ -595,3 +595,80 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip
self.assertEqual(generated_texts[0], expected_generated_text)
@slow
def test_export_smolvlm_vision_encoder(self):
from transformers import AutoConfig
from transformers.integrations.executorch import TorchExportableModuleForVLM
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
config = AutoConfig.from_pretrained(model_id)
config.text_config._flash_attn_2_enabled = False
# Load model and extract vision encoder
model = SmolVLMForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32,
config=config,
)
exportable_module = TorchExportableModuleForVLM(model)
exported_program = exportable_module.export_vision_encoder()
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
@slow
def test_export_smolvlm_connector(self):
from transformers import AutoConfig
from transformers.integrations.executorch import TorchExportableModuleForVLM
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
config = AutoConfig.from_pretrained(model_id)
config.text_config._flash_attn_2_enabled = False
# Load the model and extract the connector (multi-modal projector)
model = SmolVLMForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32,
config=config,
)
connector = model.model.connector
connector.eval()
exportable_module = TorchExportableModuleForVLM(model)
exported_program = exportable_module.export_connector()
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
@slow
def test_export_smolvlm_text_decoder(self):
from transformers import AutoConfig
from transformers.integrations.executorch import TorchExportableModuleForVLM
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
config = AutoConfig.from_pretrained(model_id)
config.text_config._flash_attn_2_enabled = False
config.text_config.use_cache = True
config.text_config.attn_implementation = "sdpa"
# Load the model and extract the text decoder
model = SmolVLMForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32,
config=config,
)
text_decoder = model.model.text_model
text_decoder.eval()
exportable_module = TorchExportableModuleForVLM(model)
exported_program = exportable_module.export_text_decoder()
self.assertIsInstance(exported_program, torch.export.ExportedProgram)