Gemma3 is Torch Exportable (#37728)

* Gemma3 is Torch Exportable

* Expand the support to other mdoels using HybridCache

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
Guang Yang
2025-04-28 00:36:46 -07:00
committed by GitHub
parent 397a5ede33
commit 816b37010c
9 changed files with 369 additions and 10 deletions

View File

@@ -337,6 +337,44 @@ class Gemma2IntegrationTest(unittest.TestCase):
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@slow
@require_read_token
def test_export_hybrid_cache(self):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
from transformers.pytorch_utils import is_torch_greater_or_equal
if not is_torch_greater_or_equal("2.6.0"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")
model_id = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
# Test generation with the exported model
prompt = "What is the capital of France?"
max_new_tokens_to_generate = 20
# Generate text with the exported model
tokenizer = AutoTokenizer.from_pretrained(model_id)
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
)
input_text = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
eager_outputs = model.generate(
**input_text,
max_new_tokens=max_new_tokens_to_generate,
do_sample=False, # Use greedy decoding to match the exported model
)
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
self.assertEqual(export_generated_text, eager_generated_text)
@require_read_token
@tooslow
def test_model_9b_bf16_flex_attention(self):