diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 6fa0e6348d..a56dceaa24 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -27,6 +27,167 @@ from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3 +# Add this to src/transformers/integrations/executorch.py + + +class TorchExportableModuleForVLM: + """ + A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch. + + This class handles the export of three main components: + 1. Vision encoder (processes images to visual features) + 2. Connector/projector (maps visual features to text embedding space) + 3. Text decoder (generates text from combined visual and text tokens) + """ + + def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024): + """ + Initialize the exportable VLM module. + + Args: + model: The VLM (e.g. SmolVLM) model instance + max_batch_size: Maximum batch size. Always 1 for ExecuTorch + max_cache_len: Maximum cache length for text generation + """ + self.model = model + self.max_batch_size = max_batch_size + self.max_cache_len = max_cache_len + self.config = model.config + + # Extract individual components + self.vision_encoder = model.model.vision_model + self.connector = model.model.connector + self.text_decoder = model.model.text_model + + # Store exported programs + self.exported_vision_encoder = None + self.exported_connector = None + self.exported_text_decoder = None + + def export_vision_encoder(self): + """Export the vision encoder component.""" + self.vision_encoder.eval() + + # Create example input + pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32) + + # Define dynamic shapes + dynamic_shapes = { + "pixel_values": { + 2: torch.export.Dim.AUTO, + 3: torch.export.Dim.AUTO, + } + } + + self.exported_vision_encoder = torch.export.export( + self.vision_encoder, + args=(pixel_values,), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + return self.exported_vision_encoder + + def export_connector(self): + """Export the connector component.""" + self.connector.eval() + + # Vision encoder output shape: [batch_size, num_patches, vision_hidden_size] + vision_hidden_size = self.config.vision_config.hidden_size + image_size = self.config.vision_config.image_size + patch_size = self.config.vision_config.patch_size + patches_per_dim = image_size // patch_size + num_patches = patches_per_dim * patches_per_dim + image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32) + + # Define dynamic shapes - static batch_size=1, dynamic num_patches + dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}} + + # Export the connector using torch.export + self.exported_connector = torch.export.export( + self.connector, + args=(image_hidden_states,), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + return self.exported_connector + + def export_text_decoder(self): + """Export the text decoder component.""" + + # Create text decoder exportable wrapper + self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM( + model=self.text_decoder, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + ) + + # Use the existing text decoder exportable wrapper + seq_length = 3 + input_ids = torch.zeros((1, seq_length), dtype=torch.long) + cache_position = torch.arange(seq_length, dtype=torch.long) + max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings) + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1) + + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + self.exported_text_decoder = self.exportable_text_decoder.export( + input_ids=input_ids, + cache_position=cache_position, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + return self.exported_text_decoder + + def export(self, **kwargs): + """Export all components of the VLM model.""" + self.export_vision_encoder(**kwargs) + self.export_connector(**kwargs) + self.export_text_decoder(**kwargs) + return { + "vision_encoder": self.exported_vision_encoder, + "connector": self.exported_connector, + "text_decoder": self.exported_text_decoder, + } + + def forward(self, pixel_values, input_ids, cache_position): + """ + Simplified forward pass for inference with guaranteed non-null input_ids and cache_position. + + Args: + pixel_values: Input images [1, channels, height, width] (optional) + input_ids: Text token IDs [1, seq_len] (required - won't be None) + cache_position: Cache positions [seq_len] (required - won't be None) + + Returns: + Output with logits for text generation + """ + pass + + def generate( + self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs + ): + """ + Simplified generate method with guaranteed non-null input_ids. + + Args: + pixel_values: Input images [1, channels, height, width] (optional) + input_ids: Initial text tokens [1, seq_len] (required - won't be None) + max_new_tokens: Maximum number of tokens to generate + do_sample: Whether to use sampling or greedy decoding + temperature: Temperature for sampling + + Returns: + Generated sequences + """ + pass + + class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): """ A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, @@ -64,7 +225,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): logging.info( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." ) - self.model = TorchExportableModuleWithStaticCache(model) + self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len) # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) @@ -254,7 +415,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`. """ - def __init__(self, model: PreTrainedModel): + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): """ Initializes the wrapper module with the pretrained model. @@ -270,9 +436,16 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): # Sanity checks if model.generation_config is None: - raise AssertionError( - "The model must have a generation config to be exported with static caching. " - "Please set `generation_config`." + # Use default generation config if not specified + model.generation_config = GenerationConfig( + use_cache=model.config.use_cache, + cache_implementation="static", + max_length=max_cache_len, + cache_config={ + "batch_size": max_batch_size, + "max_cache_len": max_cache_len, + "device": "cpu", + }, ) if not model.generation_config.use_cache: @@ -332,7 +505,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): past_key_values=past_key_values, use_cache=True, ) - return outs.logits + if hasattr(outs, "logits"): + # Returned outputs is `CausalLMOutputWithPast` + return outs.logits + else: + # Returned the `last_hidden_state` from `BaseModelOutputWithPast` + return outs.last_hidden_state @staticmethod def generate( diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index a2afefa827..d25cf5e2f2 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -147,8 +147,11 @@ class Idefics2VisionEmbeddings(nn.Module): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) + w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + + fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) + fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 9200bb7159..c2d41aac02 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -147,8 +147,11 @@ class Idefics3VisionEmbeddings(nn.Module): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) + w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + + fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) + fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) @@ -558,10 +561,10 @@ class Idefics3VisionTransformer(Idefics3PreTrainedModel): # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - patch_attention_mask = None - elif not self._use_flash_attention_2: + if not self._use_flash_attention_2: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + elif not torch.any(~patch_attention_mask): + patch_attention_mask = None encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index a9d8f043c4..7452068685 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -142,8 +142,11 @@ class SmolVLMVisionEmbeddings(nn.Module): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) + w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + + fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) + fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) @@ -445,10 +448,10 @@ class SmolVLMVisionTransformer(SmolVLMPreTrainedModel): # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - patch_attention_mask = None - elif not self._use_flash_attention_2: + if not self._use_flash_attention_2: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + elif not torch.any(~patch_attention_mask): + patch_attention_mask = None encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py index 135043e986..d1140b6ec1 100644 --- a/tests/models/smolvlm/test_modeling_smolvlm.py +++ b/tests/models/smolvlm/test_modeling_smolvlm.py @@ -595,3 +595,80 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase): expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip self.assertEqual(generated_texts[0], expected_generated_text) + + @slow + def test_export_smolvlm_vision_encoder(self): + from transformers import AutoConfig + from transformers.integrations.executorch import TorchExportableModuleForVLM + + model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct" + + # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not + # For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder + config = AutoConfig.from_pretrained(model_id) + config.text_config._flash_attn_2_enabled = False + + # Load model and extract vision encoder + model = SmolVLMForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float32, + config=config, + ) + + exportable_module = TorchExportableModuleForVLM(model) + exported_program = exportable_module.export_vision_encoder() + self.assertIsInstance(exported_program, torch.export.ExportedProgram) + + @slow + def test_export_smolvlm_connector(self): + from transformers import AutoConfig + from transformers.integrations.executorch import TorchExportableModuleForVLM + + model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct" + + # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not + # For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder + config = AutoConfig.from_pretrained(model_id) + config.text_config._flash_attn_2_enabled = False + + # Load the model and extract the connector (multi-modal projector) + model = SmolVLMForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float32, + config=config, + ) + + connector = model.model.connector + connector.eval() + + exportable_module = TorchExportableModuleForVLM(model) + exported_program = exportable_module.export_connector() + self.assertIsInstance(exported_program, torch.export.ExportedProgram) + + @slow + def test_export_smolvlm_text_decoder(self): + from transformers import AutoConfig + from transformers.integrations.executorch import TorchExportableModuleForVLM + + model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct" + + # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not + # For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder + config = AutoConfig.from_pretrained(model_id) + config.text_config._flash_attn_2_enabled = False + config.text_config.use_cache = True + config.text_config.attn_implementation = "sdpa" + + # Load the model and extract the text decoder + model = SmolVLMForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float32, + config=config, + ) + + text_decoder = model.model.text_model + text_decoder.eval() + + exportable_module = TorchExportableModuleForVLM(model) + exported_program = exportable_module.export_text_decoder() + self.assertIsInstance(exported_program, torch.export.ExportedProgram)