Gemma3 is Torch Exportable (#37728)
* Gemma3 is Torch Exportable * Expand the support to other mdoels using HybridCache --------- Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Gemma3 model."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -52,6 +53,7 @@ if is_torch_available():
|
||||
Gemma3Processor,
|
||||
Gemma3TextModel,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal
|
||||
|
||||
|
||||
class Gemma3ModelTester(GemmaModelTester):
|
||||
@@ -664,3 +666,42 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
model.generation_config.transformers_version = "4.49.0"
|
||||
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
def test_export_text_only_with_hybrid_cache(self):
|
||||
if not is_torch_greater_or_equal("2.6.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.6 to run.")
|
||||
|
||||
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
|
||||
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
self.assertEqual(model.config.cache_implementation, "hybrid")
|
||||
|
||||
# Export + HybridCache
|
||||
model.eval()
|
||||
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
|
||||
exported_program = exportable_module.export()
|
||||
logging.info(f"\nExported program: {exported_program}")
|
||||
|
||||
# Test generation with the exported model
|
||||
prompt = "What is the capital of France?"
|
||||
max_new_tokens_to_generate = 20
|
||||
# Generate text with the exported model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
|
||||
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
|
||||
)
|
||||
logging.info(f"\nExport generated texts: '{export_generated_text}'")
|
||||
|
||||
input_text = tokenizer(prompt, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
eager_outputs = model.generate(
|
||||
**input_text,
|
||||
max_new_tokens=max_new_tokens_to_generate,
|
||||
do_sample=False, # Use greedy decoding to match the exported model
|
||||
)
|
||||
|
||||
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
|
||||
logging.info(f"\nEager generated texts: '{eager_generated_text}'")
|
||||
|
||||
self.assertEqual(export_generated_text, eager_generated_text)
|
||||
|
||||
Reference in New Issue
Block a user