From 3b07ca78bb696825feee3e976795fab58f2b6d0c Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Mon, 31 Mar 2025 03:10:26 -0700 Subject: [PATCH] Export T5 (encoder-decoder) to ExecuTorch (#36486) Co-authored-by: Guang Yang --- src/transformers/integrations/executorch.py | 179 ++++++++++++++++++++ tests/models/t5/test_modeling_t5.py | 145 ++++++++++++++++ 2 files changed, 324 insertions(+) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 4ee525ddf8..09fd0c387f 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -12,6 +12,8 @@ import torch +from transformers.generation.configuration_utils import GenerationConfig + from ..utils.import_utils import is_torch_available @@ -216,3 +218,180 @@ def convert_and_export_with_cache( strict=True, ) return exported_program + + +class Seq2SeqLMEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`. + This module ensures that the exported encoder model is compatible with ExecuTorch. + """ + + def __init__(self, encoder_model): + super().__init__() + self.encoder = encoder_model + + def forward(self, input_ids): + return self.encoder(input_ids=input_ids).last_hidden_state + + +class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): + """ + A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`, + specifically for use with static caching. This module ensures the exported decoder + is compatible with ExecuTorch. + """ + + def __init__(self, model, max_static_cache_length, batch_size): + super().__init__() + + # Get the decoder component + self.decoder = model.get_decoder() + self.lm_head = model.lm_head + self.config = model.config + + # Initialize static cache + self.static_cache = StaticCache( + config=self.config, + max_batch_size=batch_size, + max_cache_len=max_static_cache_length, + device="cpu", + dtype=torch.float32, + ) + + # Register cache buffers to make them exportable + for i in range(len(self.static_cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + + def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): + # Get outputs from decoder + outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_hidden_states, + past_key_values=self.static_cache, + use_cache=True, + cache_position=cache_position, + ) + + # Apply language model head + lm_logits = self.lm_head(outputs[0]) + + return lm_logits + + +class Seq2SeqLMExportableModule(torch.nn.Module): + def __init__( + self, model, batch_size=1, max_hidden_seq_length=4096, cache_implementation="static", max_cache_length=1024 + ): + super().__init__() + + self.full_model = model + self.encoder = model.get_encoder() + self.config = model.config + self.max_hidden_seq_length = max_hidden_seq_length + self.generation_config = GenerationConfig( + use_cache=True, + max_length=max_cache_length, + cache_implementation=cache_implementation, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_cache_length, + }, + ) + self.exported_encoder = None + self.exported_decoder = None + + def _export_encoder(self, encoder_input_ids): + wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() + + # Define dynamic sequence length for encoder + seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) + + # Export the encoder + with torch.no_grad(): + exported_encoder = torch.export.export( + wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + ) + + return exported_encoder + + def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position): + wrapped_decoder = ( + Seq2SeqLMDecoderExportableModuleWithStaticCache( + model=self.full_model, + max_static_cache_length=self.generation_config.cache_config.max_cache_len, + batch_size=self.generation_config.cache_config.batch_size, + ) + .to("cpu") + .eval() + ) + + # Define dynamic dimension for encoder output sequence length + encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) + + # Export the decoder + with torch.no_grad(): + exported_decoder = torch.export.export( + wrapped_decoder, + (decoder_input_ids, encoder_hidden_states, cache_position), + dynamic_shapes={ + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_seq_len_dim}, + "cache_position": None, + }, + strict=True, + ) + + return exported_decoder + + def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None): + example_encoder_input_ids = ( + encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long) + ) + example_decoder_input_ids = ( + decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) + ) # Start token + example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) + example_encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else torch.zeros( + (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32 + ) + ) + self.exported_encoder = self._export_encoder(example_encoder_input_ids) + self.exported_decoder = self._export_decoder( + example_decoder_input_ids, example_encoder_hidden_states, example_cache_position + ) + + # Return self to allow chaining + return self + + def generate(self, prompt_token_ids, max_new_tokens): + with torch.no_grad(): + # Run encoder + encoder_output = self.exported_encoder.module()(prompt_token_ids) + + # Initialize with start token (0 for T5) + decoder_input_ids = torch.tensor([[0]], dtype=torch.long) + generated_ids = [0] + + # Generate tokens one by one + for i in range(max_new_tokens - 1): + # Run decoder for next token prediction + logits = self.exported_decoder.module()( + decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long) + ) + + # Get next token + next_token = torch.argmax(logits[:, -1, :], dim=-1).item() + generated_ids.append(next_token) + + # Update input for next iteration + decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) + + # Check if EOS token + if next_token == self.config.eos_token_id: + break + + return generated_ids diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 6c6d47e8ca..48fe5e8942 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -22,6 +22,7 @@ import unittest from transformers import T5Config, is_torch_available from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import ( require_accelerate, require_sentencepiece, @@ -1698,6 +1699,150 @@ class T5ModelIntegrationTests(unittest.TestCase): logits_compiled = model(**inputs) torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5) + @slow + def test_export_encoder(self): + """Test exporting T5EncoderModel to torch export format.""" + if not is_torch_greater_or_equal_than_2_4: + self.skipTest("This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import Seq2SeqLMEncoderExportableModule + + model_id = "google-t5/t5-small" + device = "cpu" + example_input_ids = torch.ones((1, 10), dtype=torch.long).to(device) + + # Load model + model = T5EncoderModel.from_pretrained(model_id).to(device=device).eval() + + # Get original output for comparison + with torch.no_grad(): + original_output = model(input_ids=example_input_ids).last_hidden_state + + encoder_model = Seq2SeqLMEncoderExportableModule(model) + + # Export the encoder_model + with torch.no_grad(): + seq_len_dim = torch.export.Dim("sequence_length", max=4096) + + exported_program = torch.export.export( + encoder_model, (example_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + ) + + # Test the exported model + with torch.no_grad(): + exported_output = exported_program.module()(example_input_ids) + + # Verify outputs are close enough + self.assertTrue(torch.allclose(original_output, exported_output, atol=1e-5)) + + @slow + def test_export_decoder(self): + """Test exporting T5 decoder with static cache to torch export format.""" + if not is_torch_greater_or_equal_than_2_4: + self.skipTest("This test requires torch >= 2.4 to run.") + + from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration + from transformers.integrations.executorch import Seq2SeqLMDecoderExportableModuleWithStaticCache + + model_id = "google-t5/t5-small" + + # Configuration for static cache + batch_size = 1 + max_cache_len = 123 + device = "cpu" + + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device) + self.assertIsInstance(full_model, T5ForConditionalGeneration) + decoder_model = ( + Seq2SeqLMDecoderExportableModuleWithStaticCache(full_model, max_cache_len, batch_size).to(device).eval() + ) + + # Prepare test inputs + example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long) # Start token + example_cache_position = torch.tensor([0], dtype=torch.long) + + # For T5-small, hidden size is 512 + example_encoder_hidden_states = torch.zeros((batch_size, 10, 512), dtype=torch.float32) + + # Export the model + with torch.no_grad(): + encoder_sequence_length_dim = torch.export.Dim("encoder_sequence_length", max=4096) + + exported_program = torch.export.export( + decoder_model, + (example_decoder_input_ids, example_encoder_hidden_states, example_cache_position), + dynamic_shapes={ + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_sequence_length_dim}, + "cache_position": None, + }, + strict=True, + ) + + # We won't directly verify outputs here as it's complicated with caching, + # but we'll check the export was successful + self.assertIsNotNone(exported_program) + + # Verify cache buffers existence and shapes + cache_buffers = [ + (name, buffer) + for name, buffer in exported_program.named_buffers() + if name.startswith("key_cache_") or name.startswith("value_cache_") + ] + + # Verify cache buffers + self.assertTrue(len(cache_buffers) > 0, "No cache buffers found in exported model") + for name, buffer in cache_buffers: + # Verify cache buffers are 3D + self.assertEqual(buffer.shape[2], max_cache_len) + + @slow + def test_export_t5_summarization(self): + """Test composing exported T5 encoder and decoder for summarization.""" + if not is_torch_greater_or_equal_than_2_4: + self.skipTest("This test requires torch >= 2.4 to run.") + + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration + from transformers.integrations.executorch import Seq2SeqLMExportableModule + + device = "cpu" + batch_size = 1 + max_cache_length = 1234 + max_hidden_seq_length = 5678 + model_id = "google-t5/t5-small" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + full_model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device).eval() + self.assertIsInstance(full_model, T5ForConditionalGeneration) + wrapped_model = Seq2SeqLMExportableModule( + full_model, + batch_size=batch_size, + max_hidden_seq_length=max_hidden_seq_length, + max_cache_length=max_cache_length, + ) + + exported_t5 = wrapped_model.export() + + # Test Summarization with Composed Models + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp." + ] + input_ids = tokenizer(prompts, return_tensors="pt").input_ids + + generated_ids = exported_t5.generate(prompt_token_ids=input_ids, max_new_tokens=max_cache_length) + generated_summary = tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Also run original model for comparison + original_model = T5ForConditionalGeneration.from_pretrained(model_id).eval() + with torch.no_grad(): + original_outputs = original_model.generate(input_ids, max_length=50, num_beams=1) + original_summary = tokenizer.decode(original_outputs[0], skip_special_tokens=True) + + # Basic verification that we got a reasonable summary + self.assertEqual(generated_summary, original_summary) + @require_torch class TestAsymmetricT5(unittest.TestCase):