Export T5 (encoder-decoder) to ExecuTorch (#36486)

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
Guang Yang
2025-03-31 03:10:26 -07:00
committed by GitHub
parent 475664e2c6
commit 3b07ca78bb
2 changed files with 324 additions and 0 deletions

View File

@@ -22,6 +22,7 @@ import unittest
from transformers import T5Config, is_torch_available
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
@@ -1698,6 +1699,150 @@ class T5ModelIntegrationTests(unittest.TestCase):
logits_compiled = model(**inputs)
torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5)
@slow
def test_export_encoder(self):
    """Test exporting T5EncoderModel to torch export format.

    Wraps the eager encoder in ``Seq2SeqLMEncoderExportableModule``, exports it with
    dimension 1 of ``input_ids`` marked dynamic, and checks that the exported program
    reproduces the eager ``last_hidden_state`` for the same example input.
    """
    if not is_torch_greater_or_equal_than_2_4:
        self.skipTest("This test requires torch >= 2.4 to run.")

    from transformers.integrations.executorch import Seq2SeqLMEncoderExportableModule

    model_id = "google-t5/t5-small"
    device = "cpu"
    example_input_ids = torch.ones((1, 10), dtype=torch.long).to(device)

    # Load model
    model = T5EncoderModel.from_pretrained(model_id).to(device=device).eval()

    # Get original output for comparison
    with torch.no_grad():
        original_output = model(input_ids=example_input_ids).last_hidden_state

    encoder_model = Seq2SeqLMEncoderExportableModule(model)

    # Export the encoder_model
    with torch.no_grad():
        seq_len_dim = torch.export.Dim("sequence_length", max=4096)
        exported_program = torch.export.export(
            encoder_model, (example_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
        )

    # Test the exported model
    with torch.no_grad():
        exported_output = exported_program.module()(example_input_ids)

    # Verify outputs are close enough. assert_close reports the mismatching elements on
    # failure, whereas assertTrue(torch.allclose(...)) only says "False is not true".
    torch.testing.assert_close(exported_output, original_output, rtol=1e-5, atol=1e-5)
@slow
def test_export_decoder(self):
    """Test exporting T5 decoder with static cache to torch export format.

    Wraps the full seq2seq model in ``Seq2SeqLMDecoderExportableModuleWithStaticCache``,
    exports it with a dynamic encoder sequence length, and checks that the exported
    program exposes key/value cache buffers sized to ``max_cache_len``. Output values
    are not compared here.
    """
    if not is_torch_greater_or_equal_than_2_4:
        self.skipTest("This test requires torch >= 2.4 to run.")

    from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration
    from transformers.integrations.executorch import Seq2SeqLMDecoderExportableModuleWithStaticCache

    model_id = "google-t5/t5-small"

    # Configuration for static cache
    batch_size = 1
    max_cache_len = 123
    device = "cpu"

    full_model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    self.assertIsInstance(full_model, T5ForConditionalGeneration)
    decoder_model = (
        Seq2SeqLMDecoderExportableModuleWithStaticCache(full_model, max_cache_len, batch_size).to(device).eval()
    )

    # Prepare test inputs
    example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # Start token
    example_cache_position = torch.tensor([0], dtype=torch.long)
    # Read the hidden size from the config (512 for T5-small) instead of hard-coding it,
    # so the example input stays valid if model_id is changed.
    hidden_size = full_model.config.d_model
    example_encoder_hidden_states = torch.zeros((batch_size, 10, hidden_size), dtype=torch.float32)

    # Export the model
    with torch.no_grad():
        encoder_sequence_length_dim = torch.export.Dim("encoder_sequence_length", max=4096)
        exported_program = torch.export.export(
            decoder_model,
            (example_decoder_input_ids, example_encoder_hidden_states, example_cache_position),
            dynamic_shapes={
                "decoder_input_ids": None,
                "encoder_hidden_states": {1: encoder_sequence_length_dim},
                "cache_position": None,
            },
            strict=True,
        )

    # We won't directly verify outputs here as it's complicated with caching,
    # but we'll check the export was successful
    self.assertIsNotNone(exported_program)

    # Collect the static cache buffers registered on the exported program
    cache_buffers = [
        (name, buffer)
        for name, buffer in exported_program.named_buffers()
        if name.startswith(("key_cache_", "value_cache_"))
    ]

    # Verify cache buffers exist
    self.assertGreater(len(cache_buffers), 0, "No cache buffers found in exported model")
    for name, buffer in cache_buffers:
        # Dim 2 of every cache buffer must match the configured static cache length
        self.assertEqual(buffer.shape[2], max_cache_len, f"Unexpected cache length for buffer {name}")
@slow
def test_export_t5_summarization(self):
    """Check that the exported encoder/decoder pair summarizes exactly like eager T5.

    The full model is wrapped in ``Seq2SeqLMExportableModule`` and exported; the text
    decoded from the exported pipeline must equal the text decoded from
    ``T5ForConditionalGeneration.generate`` on the same prompt.
    """
    if not is_torch_greater_or_equal_than_2_4:
        self.skipTest("This test requires torch >= 2.4 to run.")

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
    from transformers.integrations.executorch import Seq2SeqLMExportableModule

    # Export configuration
    model_id = "google-t5/t5-small"
    device = "cpu"
    batch_size = 1
    max_cache_length = 1234
    max_hidden_seq_length = 5678

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    eager_for_export = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device).eval()
    self.assertIsInstance(eager_for_export, T5ForConditionalGeneration)

    # Wrap and export encoder + decoder in one step
    exportable = Seq2SeqLMExportableModule(
        eager_for_export,
        batch_size=batch_size,
        max_hidden_seq_length=max_hidden_seq_length,
        max_cache_length=max_cache_length,
    )
    exported_t5 = exportable.export()

    # Single summarization prompt, tokenized once and shared by both generation paths
    prompts = [
        "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
        "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
        "theory of relativity is not hard to grasp."
    ]
    encoded_prompts = tokenizer(prompts, return_tensors="pt")
    input_ids = encoded_prompts.input_ids

    # Exported path
    exported_token_ids = exported_t5.generate(prompt_token_ids=input_ids, max_new_tokens=max_cache_length)
    generated_summary = tokenizer.decode(exported_token_ids, skip_special_tokens=True)

    # Eager reference path, using a freshly loaded model
    reference_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
    with torch.no_grad():
        reference_outputs = reference_model.generate(input_ids, max_length=50, num_beams=1)
    original_summary = tokenizer.decode(reference_outputs[0], skip_special_tokens=True)

    # The exported pipeline must reproduce the eager summary exactly
    self.assertEqual(generated_summary, original_summary)
@require_torch
class TestAsymmetricT5(unittest.TestCase):