Export T5 (encoder-decoder) to ExecuTorch (#36486)
Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -12,6 +12,8 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
|
|
||||||
from ..utils.import_utils import is_torch_available
|
from ..utils.import_utils import is_torch_available
|
||||||
|
|
||||||
|
|
||||||
@@ -216,3 +218,180 @@ def convert_and_export_with_cache(
|
|||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
return exported_program
|
return exported_program
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
|
||||||
|
This module ensures that the exported encoder model is compatible with ExecuTorch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, encoder_model):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder_model
|
||||||
|
|
||||||
|
def forward(self, input_ids):
|
||||||
|
return self.encoder(input_ids=input_ids).last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
|
||||||
|
specifically for use with static caching. This module ensures the exported decoder
|
||||||
|
is compatible with ExecuTorch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, max_static_cache_length, batch_size):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Get the decoder component
|
||||||
|
self.decoder = model.get_decoder()
|
||||||
|
self.lm_head = model.lm_head
|
||||||
|
self.config = model.config
|
||||||
|
|
||||||
|
# Initialize static cache
|
||||||
|
self.static_cache = StaticCache(
|
||||||
|
config=self.config,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
max_cache_len=max_static_cache_length,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register cache buffers to make them exportable
|
||||||
|
for i in range(len(self.static_cache.key_cache)):
|
||||||
|
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
|
||||||
|
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
|
||||||
|
|
||||||
|
def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
|
||||||
|
# Get outputs from decoder
|
||||||
|
outputs = self.decoder(
|
||||||
|
input_ids=decoder_input_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
past_key_values=self.static_cache,
|
||||||
|
use_cache=True,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply language model head
|
||||||
|
lm_logits = self.lm_head(outputs[0])
|
||||||
|
|
||||||
|
return lm_logits
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, model, batch_size=1, max_hidden_seq_length=4096, cache_implementation="static", max_cache_length=1024
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.full_model = model
|
||||||
|
self.encoder = model.get_encoder()
|
||||||
|
self.config = model.config
|
||||||
|
self.max_hidden_seq_length = max_hidden_seq_length
|
||||||
|
self.generation_config = GenerationConfig(
|
||||||
|
use_cache=True,
|
||||||
|
max_length=max_cache_length,
|
||||||
|
cache_implementation=cache_implementation,
|
||||||
|
cache_config={
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"max_cache_len": max_cache_length,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.exported_encoder = None
|
||||||
|
self.exported_decoder = None
|
||||||
|
|
||||||
|
def _export_encoder(self, encoder_input_ids):
|
||||||
|
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
|
||||||
|
|
||||||
|
# Define dynamic sequence length for encoder
|
||||||
|
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
|
||||||
|
|
||||||
|
# Export the encoder
|
||||||
|
with torch.no_grad():
|
||||||
|
exported_encoder = torch.export.export(
|
||||||
|
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return exported_encoder
|
||||||
|
|
||||||
|
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
|
||||||
|
wrapped_decoder = (
|
||||||
|
Seq2SeqLMDecoderExportableModuleWithStaticCache(
|
||||||
|
model=self.full_model,
|
||||||
|
max_static_cache_length=self.generation_config.cache_config.max_cache_len,
|
||||||
|
batch_size=self.generation_config.cache_config.batch_size,
|
||||||
|
)
|
||||||
|
.to("cpu")
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define dynamic dimension for encoder output sequence length
|
||||||
|
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
|
||||||
|
|
||||||
|
# Export the decoder
|
||||||
|
with torch.no_grad():
|
||||||
|
exported_decoder = torch.export.export(
|
||||||
|
wrapped_decoder,
|
||||||
|
(decoder_input_ids, encoder_hidden_states, cache_position),
|
||||||
|
dynamic_shapes={
|
||||||
|
"decoder_input_ids": None,
|
||||||
|
"encoder_hidden_states": {1: encoder_seq_len_dim},
|
||||||
|
"cache_position": None,
|
||||||
|
},
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return exported_decoder
|
||||||
|
|
||||||
|
def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
|
||||||
|
example_encoder_input_ids = (
|
||||||
|
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
|
||||||
|
)
|
||||||
|
example_decoder_input_ids = (
|
||||||
|
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
|
||||||
|
) # Start token
|
||||||
|
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||||
|
example_encoder_hidden_states = (
|
||||||
|
encoder_hidden_states
|
||||||
|
if encoder_hidden_states is not None
|
||||||
|
else torch.zeros(
|
||||||
|
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
|
||||||
|
self.exported_decoder = self._export_decoder(
|
||||||
|
example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return self to allow chaining
|
||||||
|
return self
|
||||||
|
|
||||||
|
def generate(self, prompt_token_ids, max_new_tokens):
|
||||||
|
with torch.no_grad():
|
||||||
|
# Run encoder
|
||||||
|
encoder_output = self.exported_encoder.module()(prompt_token_ids)
|
||||||
|
|
||||||
|
# Initialize with start token (0 for T5)
|
||||||
|
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
|
||||||
|
generated_ids = [0]
|
||||||
|
|
||||||
|
# Generate tokens one by one
|
||||||
|
for i in range(max_new_tokens - 1):
|
||||||
|
# Run decoder for next token prediction
|
||||||
|
logits = self.exported_decoder.module()(
|
||||||
|
decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get next token
|
||||||
|
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
|
||||||
|
generated_ids.append(next_token)
|
||||||
|
|
||||||
|
# Update input for next iteration
|
||||||
|
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)
|
||||||
|
|
||||||
|
# Check if EOS token
|
||||||
|
if next_token == self.config.eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
return generated_ids
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import T5Config, is_torch_available
|
from transformers import T5Config, is_torch_available
|
||||||
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
@@ -1698,6 +1699,150 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||||||
logits_compiled = model(**inputs)
|
logits_compiled = model(**inputs)
|
||||||
torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5)
|
torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_encoder(self):
|
||||||
|
"""Test exporting T5EncoderModel to torch export format."""
|
||||||
|
if not is_torch_greater_or_equal_than_2_4:
|
||||||
|
self.skipTest("This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
from transformers.integrations.executorch import Seq2SeqLMEncoderExportableModule
|
||||||
|
|
||||||
|
model_id = "google-t5/t5-small"
|
||||||
|
device = "cpu"
|
||||||
|
example_input_ids = torch.ones((1, 10), dtype=torch.long).to(device)
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = T5EncoderModel.from_pretrained(model_id).to(device=device).eval()
|
||||||
|
|
||||||
|
# Get original output for comparison
|
||||||
|
with torch.no_grad():
|
||||||
|
original_output = model(input_ids=example_input_ids).last_hidden_state
|
||||||
|
|
||||||
|
encoder_model = Seq2SeqLMEncoderExportableModule(model)
|
||||||
|
|
||||||
|
# Export the encoder_model
|
||||||
|
with torch.no_grad():
|
||||||
|
seq_len_dim = torch.export.Dim("sequence_length", max=4096)
|
||||||
|
|
||||||
|
exported_program = torch.export.export(
|
||||||
|
encoder_model, (example_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test the exported model
|
||||||
|
with torch.no_grad():
|
||||||
|
exported_output = exported_program.module()(example_input_ids)
|
||||||
|
|
||||||
|
# Verify outputs are close enough
|
||||||
|
self.assertTrue(torch.allclose(original_output, exported_output, atol=1e-5))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_decoder(self):
|
||||||
|
"""Test exporting T5 decoder with static cache to torch export format."""
|
||||||
|
if not is_torch_greater_or_equal_than_2_4:
|
||||||
|
self.skipTest("This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration
|
||||||
|
from transformers.integrations.executorch import Seq2SeqLMDecoderExportableModuleWithStaticCache
|
||||||
|
|
||||||
|
model_id = "google-t5/t5-small"
|
||||||
|
|
||||||
|
# Configuration for static cache
|
||||||
|
batch_size = 1
|
||||||
|
max_cache_len = 123
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
full_model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
|
||||||
|
self.assertIsInstance(full_model, T5ForConditionalGeneration)
|
||||||
|
decoder_model = (
|
||||||
|
Seq2SeqLMDecoderExportableModuleWithStaticCache(full_model, max_cache_len, batch_size).to(device).eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare test inputs
|
||||||
|
example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long) # Start token
|
||||||
|
example_cache_position = torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
# For T5-small, hidden size is 512
|
||||||
|
example_encoder_hidden_states = torch.zeros((batch_size, 10, 512), dtype=torch.float32)
|
||||||
|
|
||||||
|
# Export the model
|
||||||
|
with torch.no_grad():
|
||||||
|
encoder_sequence_length_dim = torch.export.Dim("encoder_sequence_length", max=4096)
|
||||||
|
|
||||||
|
exported_program = torch.export.export(
|
||||||
|
decoder_model,
|
||||||
|
(example_decoder_input_ids, example_encoder_hidden_states, example_cache_position),
|
||||||
|
dynamic_shapes={
|
||||||
|
"decoder_input_ids": None,
|
||||||
|
"encoder_hidden_states": {1: encoder_sequence_length_dim},
|
||||||
|
"cache_position": None,
|
||||||
|
},
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We won't directly verify outputs here as it's complicated with caching,
|
||||||
|
# but we'll check the export was successful
|
||||||
|
self.assertIsNotNone(exported_program)
|
||||||
|
|
||||||
|
# Verify cache buffers existence and shapes
|
||||||
|
cache_buffers = [
|
||||||
|
(name, buffer)
|
||||||
|
for name, buffer in exported_program.named_buffers()
|
||||||
|
if name.startswith("key_cache_") or name.startswith("value_cache_")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Verify cache buffers
|
||||||
|
self.assertTrue(len(cache_buffers) > 0, "No cache buffers found in exported model")
|
||||||
|
for name, buffer in cache_buffers:
|
||||||
|
# Verify cache buffers are 3D
|
||||||
|
self.assertEqual(buffer.shape[2], max_cache_len)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_export_t5_summarization(self):
|
||||||
|
"""Test composing exported T5 encoder and decoder for summarization."""
|
||||||
|
if not is_torch_greater_or_equal_than_2_4:
|
||||||
|
self.skipTest("This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
|
||||||
|
from transformers.integrations.executorch import Seq2SeqLMExportableModule
|
||||||
|
|
||||||
|
device = "cpu"
|
||||||
|
batch_size = 1
|
||||||
|
max_cache_length = 1234
|
||||||
|
max_hidden_seq_length = 5678
|
||||||
|
model_id = "google-t5/t5-small"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
full_model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device).eval()
|
||||||
|
self.assertIsInstance(full_model, T5ForConditionalGeneration)
|
||||||
|
wrapped_model = Seq2SeqLMExportableModule(
|
||||||
|
full_model,
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_hidden_seq_length=max_hidden_seq_length,
|
||||||
|
max_cache_length=max_cache_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
exported_t5 = wrapped_model.export()
|
||||||
|
|
||||||
|
# Test Summarization with Composed Models
|
||||||
|
prompts = [
|
||||||
|
"summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||||
|
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||||
|
"theory of relativity is not hard to grasp."
|
||||||
|
]
|
||||||
|
input_ids = tokenizer(prompts, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
generated_ids = exported_t5.generate(prompt_token_ids=input_ids, max_new_tokens=max_cache_length)
|
||||||
|
generated_summary = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Also run original model for comparison
|
||||||
|
original_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
original_outputs = original_model.generate(input_ids, max_length=50, num_beams=1)
|
||||||
|
original_summary = tokenizer.decode(original_outputs[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Basic verification that we got a reasonable summary
|
||||||
|
self.assertEqual(generated_summary, original_summary)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TestAsymmetricT5(unittest.TestCase):
|
class TestAsymmetricT5(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user