Gemma3 is Torch Exportable (#37728)
* Gemma3 is Torch Exportable * Expand the support to other mdoels using HybridCache --------- Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -20,15 +20,207 @@ from ..utils.import_utils import is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import PreTrainedModel, StaticCache
|
from transformers import HybridCache, PreTrainedModel, StaticCache
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
||||||
|
|
||||||
|
|
||||||
|
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
||||||
|
specifically for decoder-only LM with cache. This module ensures that the
|
||||||
|
exported model is compatible with further lowering and execution in `ExecuTorch`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
max_batch_size: int = 1,
|
||||||
|
max_cache_len: int = 4096,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the exportable module with `HybridCache`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`PreTrainedModel`): The pretrained model to wrap.
|
||||||
|
max_batch_size (int): Maximum batch size for the cache.
|
||||||
|
max_cache_len (int): Maximum sequence length for the cache.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the model is configured with a unsupported cache implementation.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if model.config.cache_implementation == "static":
|
||||||
|
self.model = TorchExportableModuleWithStaticCache(model)
|
||||||
|
elif model.config.cache_implementation == "hybrid":
|
||||||
|
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
|
||||||
|
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Logits output from the model.
|
||||||
|
"""
|
||||||
|
return self.model.forward(input_ids, cache_position)
|
||||||
|
|
||||||
|
def export(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
cache_position: Optional[torch.Tensor] = None,
|
||||||
|
dynamic_shapes: Optional[dict] = None,
|
||||||
|
strict: Optional[bool] = None,
|
||||||
|
) -> torch.export.ExportedProgram:
|
||||||
|
"""
|
||||||
|
Export the wrapped module using `torch.export`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids (`Optional[torch.Tensor]`):
|
||||||
|
Tensor representing current input token id to the module. If not provided, a default tensor will be used.
|
||||||
|
cache_position (`Optional[torch.Tensor]`):
|
||||||
|
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
|
||||||
|
dynamic_shapes (`Optional[dict]`):
|
||||||
|
Dynamic shapes to use for export if specified.
|
||||||
|
strict(`Optional[bool]`):
|
||||||
|
Flag to instruct `torch.export` to use `torchdynamo`.
|
||||||
|
"""
|
||||||
|
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
||||||
|
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
return torch.export.export(
|
||||||
|
self.model,
|
||||||
|
args=(example_input_ids, example_cache_position),
|
||||||
|
kwargs={},
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
strict=strict if strict is not None else True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate(
|
||||||
|
exported_program: torch.export.ExportedProgram,
|
||||||
|
tokenizer,
|
||||||
|
prompt: str,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
do_sample: bool = False,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
device: str = "cpu",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a sequence of tokens using an exported program.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
|
||||||
|
tokenizer: The tokenizer to use.
|
||||||
|
prompt (str): The input prompt.
|
||||||
|
max_new_tokens (int): Maximum number of new tokens to generate.
|
||||||
|
do_sample (bool): Whether to use sampling or greedy decoding.
|
||||||
|
temperature (float): The temperature for sampling.
|
||||||
|
top_k (int): The number of highest probability tokens to keep for top-k sampling.
|
||||||
|
top_p (float): The cumulative probability for nucleus sampling.
|
||||||
|
device (str): The device to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated text.
|
||||||
|
"""
|
||||||
|
# Get the module from the exported program
|
||||||
|
exported_module = exported_program.module()
|
||||||
|
|
||||||
|
# Tokenize the prompt
|
||||||
|
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||||
|
|
||||||
|
# Initialize with the prompt
|
||||||
|
generated_ids = input_ids.clone()
|
||||||
|
|
||||||
|
# Process the prompt tokens first
|
||||||
|
curr_position = 0
|
||||||
|
for i in range(input_ids.shape[1]):
|
||||||
|
# Process one token at a time
|
||||||
|
curr_input_ids = input_ids[:, i : i + 1]
|
||||||
|
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
_ = exported_module(curr_input_ids, curr_cache_position)
|
||||||
|
curr_position += 1
|
||||||
|
|
||||||
|
# Generate new tokens
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
# Get the last token as input
|
||||||
|
curr_input_ids = generated_ids[:, -1:]
|
||||||
|
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# Forward pass to get next token logits
|
||||||
|
outputs = exported_module(curr_input_ids, curr_cache_position)
|
||||||
|
|
||||||
|
# Get the next token ID
|
||||||
|
if do_sample:
|
||||||
|
# Apply temperature
|
||||||
|
if temperature > 0:
|
||||||
|
logits = outputs / temperature
|
||||||
|
else:
|
||||||
|
logits = outputs
|
||||||
|
|
||||||
|
# Apply top-k filtering
|
||||||
|
if top_k > 0:
|
||||||
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||||
|
logits[indices_to_remove] = float("-inf")
|
||||||
|
|
||||||
|
# Apply top-p (nucleus) filtering
|
||||||
|
if top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
|
||||||
|
# Remove tokens with cumulative probability above the threshold
|
||||||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
|
# Shift the indices to the right to keep also the first token above the threshold
|
||||||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||||
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
|
# Scatter sorted tensors to original indexing
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
logits[indices_to_remove] = float("-inf")
|
||||||
|
|
||||||
|
# Sample from the filtered distribution
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
next_token_id = torch.multinomial(probs, num_samples=1)
|
||||||
|
else:
|
||||||
|
# Greedy decoding
|
||||||
|
next_token_id = outputs.argmax(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Ensure next_token_id has the right shape before concatenation
|
||||||
|
if next_token_id.dim() > 2:
|
||||||
|
next_token_id = next_token_id.squeeze(-1)
|
||||||
|
|
||||||
|
# Append to the generated sequence
|
||||||
|
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
|
||||||
|
curr_position += 1
|
||||||
|
|
||||||
|
# Stop if we generate an EOS token
|
||||||
|
if next_token_id.item() == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Decode the generated text
|
||||||
|
return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
||||||
specifically for use with static caching. This module ensures that the exported model
|
specifically for decoder-only LM to `StaticCache`. This module ensures that the
|
||||||
is compatible with further lowering and execution in `ExecuTorch`.
|
exported model is compatible with further lowering and execution in `ExecuTorch`.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This class is specifically designed to support export process using `torch.export`
|
This class is specifically designed to support export process using `torch.export`
|
||||||
@@ -178,6 +370,94 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
return torch.tensor([response_tokens], dtype=torch.long)
|
return torch.tensor([response_tokens], dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
|
||||||
|
specifically for decoder-only LM to `HybridCache`. This module ensures that the
|
||||||
|
exported model is compatible with further lowering and execution in `ExecuTorch`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
max_batch_size: int = 1,
|
||||||
|
max_cache_len: int = 4096,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the exportable module with `HybridCache`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`PreTrainedModel`): The pretrained model to wrap.
|
||||||
|
max_batch_size (int): Maximum batch size for the cache.
|
||||||
|
max_cache_len (int): Maximum sequence length for the cache.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the model doesn't have the expected configuration for HybridCache.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# Verify the model is configured for HybridCache
|
||||||
|
if not self.model.config.use_cache:
|
||||||
|
raise AssertionError("Model must have caching enabled")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not hasattr(self.model.config, "cache_implementation")
|
||||||
|
or self.model.config.cache_implementation != "hybrid"
|
||||||
|
):
|
||||||
|
raise AssertionError("Model must use 'hybrid' cache implementation")
|
||||||
|
|
||||||
|
# Initialize the HybridCache
|
||||||
|
self.cache = HybridCache(
|
||||||
|
config=self.model.config,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_cache_len=max_cache_len,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=self.model.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register all key and value cache tensors as buffers
|
||||||
|
for i in range(len(self.cache.key_cache)):
|
||||||
|
self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
|
||||||
|
self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
|
||||||
|
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Logits output from the model.
|
||||||
|
"""
|
||||||
|
batch_size, seq_len = input_ids.shape
|
||||||
|
|
||||||
|
# Generate position_ids from cache_position
|
||||||
|
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
||||||
|
|
||||||
|
# Create attention mask (always ones for token-by-token generation)
|
||||||
|
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
|
||||||
|
|
||||||
|
# Forward pass with the model
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=self.cache,
|
||||||
|
use_cache=True,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return only the logits to simplify the export
|
||||||
|
return outputs.logits
|
||||||
|
|
||||||
|
|
||||||
def convert_and_export_with_cache(
|
def convert_and_export_with_cache(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
example_input_ids: Optional[torch.Tensor] = None,
|
example_input_ids: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@@ -351,7 +351,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -364,7 +364,7 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -494,7 +494,7 @@ class Gemma3DecoderLayer(nn.Module):
|
|||||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||||
offset = cache_position[-1] - effective_seq_len + 1
|
offset = cache_position[-1] - effective_seq_len + 1
|
||||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||||
offset = max(0, offset)
|
offset = torch.clamp(offset, min=0)
|
||||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||||
mask_indexes = torch.arange(
|
mask_indexes = torch.arange(
|
||||||
|
|||||||
@@ -337,6 +337,44 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_read_token
|
||||||
|
def test_export_hybrid_cache(self):
|
||||||
|
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal
|
||||||
|
|
||||||
|
if not is_torch_greater_or_equal("2.6.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.6 to run.")
|
||||||
|
|
||||||
|
model_id = "google/gemma-2-2b"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
self.assertEqual(model.config.cache_implementation, "hybrid")
|
||||||
|
|
||||||
|
# Export + HybridCache
|
||||||
|
model.eval()
|
||||||
|
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||||
|
exported_program = exportable_module.export()
|
||||||
|
|
||||||
|
# Test generation with the exported model
|
||||||
|
prompt = "What is the capital of France?"
|
||||||
|
max_new_tokens_to_generate = 20
|
||||||
|
# Generate text with the exported model
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
|
||||||
|
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
|
||||||
|
)
|
||||||
|
|
||||||
|
input_text = tokenizer(prompt, return_tensors="pt")
|
||||||
|
with torch.no_grad():
|
||||||
|
eager_outputs = model.generate(
|
||||||
|
**input_text,
|
||||||
|
max_new_tokens=max_new_tokens_to_generate,
|
||||||
|
do_sample=False, # Use greedy decoding to match the exported model
|
||||||
|
)
|
||||||
|
|
||||||
|
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
|
||||||
|
self.assertEqual(export_generated_text, eager_generated_text)
|
||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
@tooslow
|
@tooslow
|
||||||
def test_model_9b_bf16_flex_attention(self):
|
def test_model_9b_bf16_flex_attention(self):
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch Gemma3 model."""
|
"""Testing suite for the PyTorch Gemma3 model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ if is_torch_available():
|
|||||||
Gemma3Processor,
|
Gemma3Processor,
|
||||||
Gemma3TextModel,
|
Gemma3TextModel,
|
||||||
)
|
)
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal
|
||||||
|
|
||||||
|
|
||||||
class Gemma3ModelTester(GemmaModelTester):
|
class Gemma3ModelTester(GemmaModelTester):
|
||||||
@@ -664,3 +666,42 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
model.generation_config.transformers_version = "4.49.0"
|
model.generation_config.transformers_version = "4.49.0"
|
||||||
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
||||||
out = model.generate(**inputs, generation_config=generation_config)
|
out = model.generate(**inputs, generation_config=generation_config)
|
||||||
|
|
||||||
|
def test_export_text_only_with_hybrid_cache(self):
|
||||||
|
if not is_torch_greater_or_equal("2.6.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.6 to run.")
|
||||||
|
|
||||||
|
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||||
|
|
||||||
|
model_id = "google/gemma-3-1b-it"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
self.assertEqual(model.config.cache_implementation, "hybrid")
|
||||||
|
|
||||||
|
# Export + HybridCache
|
||||||
|
model.eval()
|
||||||
|
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||||
|
exported_program = exportable_module.export()
|
||||||
|
logging.info(f"\nExported program: {exported_program}")
|
||||||
|
|
||||||
|
# Test generation with the exported model
|
||||||
|
prompt = "What is the capital of France?"
|
||||||
|
max_new_tokens_to_generate = 20
|
||||||
|
# Generate text with the exported model
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
|
||||||
|
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
|
||||||
|
)
|
||||||
|
logging.info(f"\nExport generated texts: '{export_generated_text}'")
|
||||||
|
|
||||||
|
input_text = tokenizer(prompt, return_tensors="pt")
|
||||||
|
with torch.no_grad():
|
||||||
|
eager_outputs = model.generate(
|
||||||
|
**input_text,
|
||||||
|
max_new_tokens=max_new_tokens_to_generate,
|
||||||
|
do_sample=False, # Use greedy decoding to match the exported model
|
||||||
|
)
|
||||||
|
|
||||||
|
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
|
||||||
|
logging.info(f"\nEager generated texts: '{eager_generated_text}'")
|
||||||
|
|
||||||
|
self.assertEqual(export_generated_text, eager_generated_text)
|
||||||
|
|||||||
Reference in New Issue
Block a user