@@ -27,6 +27,167 @@ from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|||||||
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
||||||
|
|
||||||
|
|
||||||
|
# Add this to src/transformers/integrations/executorch.py
|
||||||
|
|
||||||
|
|
||||||
|
class TorchExportableModuleForVLM:
|
||||||
|
"""
|
||||||
|
A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
|
||||||
|
|
||||||
|
This class handles the export of three main components:
|
||||||
|
1. Vision encoder (processes images to visual features)
|
||||||
|
2. Connector/projector (maps visual features to text embedding space)
|
||||||
|
3. Text decoder (generates text from combined visual and text tokens)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
|
||||||
|
"""
|
||||||
|
Initialize the exportable VLM module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The VLM (e.g. SmolVLM) model instance
|
||||||
|
max_batch_size: Maximum batch size. Always 1 for ExecuTorch
|
||||||
|
max_cache_len: Maximum cache length for text generation
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_cache_len = max_cache_len
|
||||||
|
self.config = model.config
|
||||||
|
|
||||||
|
# Extract individual components
|
||||||
|
self.vision_encoder = model.model.vision_model
|
||||||
|
self.connector = model.model.connector
|
||||||
|
self.text_decoder = model.model.text_model
|
||||||
|
|
||||||
|
# Store exported programs
|
||||||
|
self.exported_vision_encoder = None
|
||||||
|
self.exported_connector = None
|
||||||
|
self.exported_text_decoder = None
|
||||||
|
|
||||||
|
def export_vision_encoder(self):
|
||||||
|
"""Export the vision encoder component."""
|
||||||
|
self.vision_encoder.eval()
|
||||||
|
|
||||||
|
# Create example input
|
||||||
|
pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Define dynamic shapes
|
||||||
|
dynamic_shapes = {
|
||||||
|
"pixel_values": {
|
||||||
|
2: torch.export.Dim.AUTO,
|
||||||
|
3: torch.export.Dim.AUTO,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.exported_vision_encoder = torch.export.export(
|
||||||
|
self.vision_encoder,
|
||||||
|
args=(pixel_values,),
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.exported_vision_encoder
|
||||||
|
|
||||||
|
def export_connector(self):
|
||||||
|
"""Export the connector component."""
|
||||||
|
self.connector.eval()
|
||||||
|
|
||||||
|
# Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
|
||||||
|
vision_hidden_size = self.config.vision_config.hidden_size
|
||||||
|
image_size = self.config.vision_config.image_size
|
||||||
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
patches_per_dim = image_size // patch_size
|
||||||
|
num_patches = patches_per_dim * patches_per_dim
|
||||||
|
image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Define dynamic shapes - static batch_size=1, dynamic num_patches
|
||||||
|
dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
|
||||||
|
|
||||||
|
# Export the connector using torch.export
|
||||||
|
self.exported_connector = torch.export.export(
|
||||||
|
self.connector,
|
||||||
|
args=(image_hidden_states,),
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.exported_connector
|
||||||
|
|
||||||
|
def export_text_decoder(self):
|
||||||
|
"""Export the text decoder component."""
|
||||||
|
|
||||||
|
# Create text decoder exportable wrapper
|
||||||
|
self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(
|
||||||
|
model=self.text_decoder,
|
||||||
|
max_batch_size=self.max_batch_size,
|
||||||
|
max_cache_len=self.max_cache_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the existing text decoder exportable wrapper
|
||||||
|
seq_length = 3
|
||||||
|
input_ids = torch.zeros((1, seq_length), dtype=torch.long)
|
||||||
|
cache_position = torch.arange(seq_length, dtype=torch.long)
|
||||||
|
max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
|
||||||
|
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
|
||||||
|
|
||||||
|
dynamic_shapes = {
|
||||||
|
"input_ids": {1: seq_len_dim},
|
||||||
|
"cache_position": {0: seq_len_dim},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.exported_text_decoder = self.exportable_text_decoder.export(
|
||||||
|
input_ids=input_ids,
|
||||||
|
cache_position=cache_position,
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.exported_text_decoder
|
||||||
|
|
||||||
|
def export(self, **kwargs):
|
||||||
|
"""Export all components of the VLM model."""
|
||||||
|
self.export_vision_encoder(**kwargs)
|
||||||
|
self.export_connector(**kwargs)
|
||||||
|
self.export_text_decoder(**kwargs)
|
||||||
|
return {
|
||||||
|
"vision_encoder": self.exported_vision_encoder,
|
||||||
|
"connector": self.exported_connector,
|
||||||
|
"text_decoder": self.exported_text_decoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, pixel_values, input_ids, cache_position):
|
||||||
|
"""
|
||||||
|
Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values: Input images [1, channels, height, width] (optional)
|
||||||
|
input_ids: Text token IDs [1, seq_len] (required - won't be None)
|
||||||
|
cache_position: Cache positions [seq_len] (required - won't be None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output with logits for text generation
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Simplified generate method with guaranteed non-null input_ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values: Input images [1, channels, height, width] (optional)
|
||||||
|
input_ids: Initial text tokens [1, seq_len] (required - won't be None)
|
||||||
|
max_new_tokens: Maximum number of tokens to generate
|
||||||
|
do_sample: Whether to use sampling or greedy decoding
|
||||||
|
temperature: Temperature for sampling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated sequences
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
||||||
@@ -64,7 +225,7 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
|||||||
logging.info(
|
logging.info(
|
||||||
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
|
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
|
||||||
)
|
)
|
||||||
self.model = TorchExportableModuleWithStaticCache(model)
|
self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
|
||||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||||
@@ -254,7 +415,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
|
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: PreTrainedModel):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
max_batch_size: int = 1,
|
||||||
|
max_cache_len: int = 4096,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the wrapper module with the pretrained model.
|
Initializes the wrapper module with the pretrained model.
|
||||||
|
|
||||||
@@ -270,9 +436,16 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
if model.generation_config is None:
|
if model.generation_config is None:
|
||||||
raise AssertionError(
|
# Use default generation config if not specified
|
||||||
"The model must have a generation config to be exported with static caching. "
|
model.generation_config = GenerationConfig(
|
||||||
"Please set `generation_config`."
|
use_cache=model.config.use_cache,
|
||||||
|
cache_implementation="static",
|
||||||
|
max_length=max_cache_len,
|
||||||
|
cache_config={
|
||||||
|
"batch_size": max_batch_size,
|
||||||
|
"max_cache_len": max_cache_len,
|
||||||
|
"device": "cpu",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not model.generation_config.use_cache:
|
if not model.generation_config.use_cache:
|
||||||
@@ -332,7 +505,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
if hasattr(outs, "logits"):
|
||||||
|
# Returned outputs is `CausalLMOutputWithPast`
|
||||||
return outs.logits
|
return outs.logits
|
||||||
|
else:
|
||||||
|
# Returned the `last_hidden_state` from `BaseModelOutputWithPast`
|
||||||
|
return outs.last_hidden_state
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(
|
def generate(
|
||||||
|
|||||||
@@ -147,8 +147,11 @@ class Idefics2VisionEmbeddings(nn.Module):
|
|||||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
nb_patches_w = p_attn_mask[0].sum()
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
|
|
||||||
|
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
|
||||||
|
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
|
||||||
|
|
||||||
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
||||||
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
||||||
|
|||||||
@@ -147,8 +147,11 @@ class Idefics3VisionEmbeddings(nn.Module):
|
|||||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
nb_patches_w = p_attn_mask[0].sum()
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
|
|
||||||
|
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
|
||||||
|
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
|
||||||
|
|
||||||
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
||||||
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
||||||
@@ -558,10 +561,10 @@ class Idefics3VisionTransformer(Idefics3PreTrainedModel):
|
|||||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||||
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||||
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||||
if not torch.any(~patch_attention_mask):
|
if not self._use_flash_attention_2:
|
||||||
patch_attention_mask = None
|
|
||||||
elif not self._use_flash_attention_2:
|
|
||||||
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
||||||
|
elif not torch.any(~patch_attention_mask):
|
||||||
|
patch_attention_mask = None
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
|
|||||||
@@ -142,8 +142,11 @@ class SmolVLMVisionEmbeddings(nn.Module):
|
|||||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
nb_patches_w = p_attn_mask[0].sum()
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
|
||||||
|
|
||||||
|
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
|
||||||
|
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
|
||||||
|
|
||||||
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
||||||
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
||||||
@@ -445,10 +448,10 @@ class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
|
|||||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||||
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||||
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||||
if not torch.any(~patch_attention_mask):
|
if not self._use_flash_attention_2:
|
||||||
patch_attention_mask = None
|
|
||||||
elif not self._use_flash_attention_2:
|
|
||||||
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
||||||
|
elif not torch.any(~patch_attention_mask):
|
||||||
|
patch_attention_mask = None
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
|
|||||||
@@ -595,3 +595,80 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip
|
expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip
|
||||||
self.assertEqual(generated_texts[0], expected_generated_text)
|
self.assertEqual(generated_texts[0], expected_generated_text)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_smolvlm_vision_encoder(self):
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers.integrations.executorch import TorchExportableModuleForVLM
|
||||||
|
|
||||||
|
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
|
||||||
|
|
||||||
|
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
|
||||||
|
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
|
||||||
|
config = AutoConfig.from_pretrained(model_id)
|
||||||
|
config.text_config._flash_attn_2_enabled = False
|
||||||
|
|
||||||
|
# Load model and extract vision encoder
|
||||||
|
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
exportable_module = TorchExportableModuleForVLM(model)
|
||||||
|
exported_program = exportable_module.export_vision_encoder()
|
||||||
|
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_smolvlm_connector(self):
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers.integrations.executorch import TorchExportableModuleForVLM
|
||||||
|
|
||||||
|
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
|
||||||
|
|
||||||
|
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
|
||||||
|
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
|
||||||
|
config = AutoConfig.from_pretrained(model_id)
|
||||||
|
config.text_config._flash_attn_2_enabled = False
|
||||||
|
|
||||||
|
# Load the model and extract the connector (multi-modal projector)
|
||||||
|
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
connector = model.model.connector
|
||||||
|
connector.eval()
|
||||||
|
|
||||||
|
exportable_module = TorchExportableModuleForVLM(model)
|
||||||
|
exported_program = exportable_module.export_connector()
|
||||||
|
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_smolvlm_text_decoder(self):
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers.integrations.executorch import TorchExportableModuleForVLM
|
||||||
|
|
||||||
|
model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
|
||||||
|
|
||||||
|
# NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not
|
||||||
|
# For ExecuTorch, flash attention is not supported, so the way of exporting vison encoder should be compatible with text-decoder
|
||||||
|
config = AutoConfig.from_pretrained(model_id)
|
||||||
|
config.text_config._flash_attn_2_enabled = False
|
||||||
|
config.text_config.use_cache = True
|
||||||
|
config.text_config.attn_implementation = "sdpa"
|
||||||
|
|
||||||
|
# Load the model and extract the text decoder
|
||||||
|
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_decoder = model.model.text_model
|
||||||
|
text_decoder.eval()
|
||||||
|
|
||||||
|
exportable_module = TorchExportableModuleForVLM(model)
|
||||||
|
exported_program = exportable_module.export_text_decoder()
|
||||||
|
self.assertIsInstance(exported_program, torch.export.ExportedProgram)
|
||||||
|
|||||||
Reference in New Issue
Block a user