From fc700c2a26af9e1b27162d408e8edfa2903c715f Mon Sep 17 00:00:00 2001 From: Stonepia Date: Thu, 17 Jul 2025 21:12:32 +0800 Subject: [PATCH] Fix convert_and_export_with_cache failures for GPU models (#38976) * Add the `device` option for `generate()` * Add device for default tensors to avoid tensor mismatch * [test] Enable test_static_cache_exportability for torch_device * infer device from the prompt_token_ids * Add device for generated tensor * [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU * fix format * infer device from the model --- src/transformers/integrations/executorch.py | 58 ++++++++++++++----- tests/models/cohere2/test_modeling_cohere2.py | 2 +- tests/models/gemma/test_modeling_gemma.py | 2 +- tests/models/gemma2/test_modeling_gemma2.py | 2 +- tests/models/llama/test_modeling_llama.py | 2 +- tests/models/olmo/test_modeling_olmo.py | 2 +- tests/models/olmo2/test_modeling_olmo2.py | 2 +- tests/models/phi3/test_modeling_phi3.py | 2 +- tests/models/qwen2/test_modeling_qwen2.py | 2 +- tests/models/qwen3/test_modeling_qwen3.py | 2 +- tests/models/t5/test_modeling_t5.py | 2 +- tests/utils/test_cache_utils.py | 6 +- 12 files changed, 57 insertions(+), 27 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 41a30c1374..71777d123c 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -107,9 +107,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`. """ + if hasattr(self.model, "base_model_prefix"): + base = getattr(self.model, self.model.base_model_prefix, self.model) + model_device = base.device + elif hasattr(self.model, "model"): + model_device = self.model.model.device + else: + model_device = "cpu" + logging.warning( + "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." + ) - example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long) - example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) + example_input_ids = ( + input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device) + ) + example_cache_position = ( + cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) + ) exported_program = torch.export.export( self.model, @@ -322,7 +336,9 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): @staticmethod def generate( - exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int + exported_program: torch.export.ExportedProgram, + prompt_token_ids: torch.Tensor, + max_new_tokens: int, ) -> torch.Tensor: """ Generate a sequence of tokens using an exported program. @@ -341,6 +357,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): Returns: torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens. """ + device = prompt_token_ids.device prompt_token_len = prompt_token_ids.shape[-1] max_generation_length = prompt_token_len + max_new_tokens for buffer_name, buffer in exported_program.named_buffers(): @@ -353,7 +370,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): for input_pos in range(min(max_generation_length, prompt_token_len)): result = exported_program.module().forward( input_ids=prompt_token_ids[:, input_pos : input_pos + 1], - cache_position=torch.tensor([input_pos], dtype=torch.long), + cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), ) response_tokens.append(prompt_token_ids[0][input_pos].item()) @@ -362,13 +379,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): while len(response_tokens) < max_generation_length: result = exported_program.module().forward( - input_ids=torch.tensor([[current_token]], dtype=torch.long), - cache_position=torch.tensor([len(response_tokens)], dtype=torch.long), + input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), + cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device), ) current_token = torch.argmax(result[:, -1, :], dim=-1).item() response_tokens.append(current_token) - return torch.tensor([response_tokens], dtype=torch.long) + return torch.tensor([response_tokens], dtype=torch.long, device=device) class TorchExportableModuleWithHybridCache(torch.nn.Module): @@ -484,10 +501,14 @@ def convert_and_export_with_cache( with torch.no_grad(): # TODO: The default inputs only work for text models. We need to add support for vision/audio models. example_input_ids = ( - example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long) + example_input_ids + if example_input_ids is not None + else torch.tensor([[1]], dtype=torch.long, device=model.device) ) example_cache_position = ( - example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long) + example_cache_position + if example_cache_position is not None + else torch.tensor([0], dtype=torch.long, device=model.device) ) if is_torch_greater_or_equal("2.6.0"): @@ -602,7 +623,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module): self.exported_decoder = None def _export_encoder(self, encoder_input_ids): - wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() + wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval() # Define dynamic sequence length for encoder seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) @@ -645,18 +666,27 @@ class Seq2SeqLMExportableModule(torch.nn.Module): return exported_decoder def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None): + device = self.full_model.device example_encoder_input_ids = ( - encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long) + encoder_input_ids + if encoder_input_ids is not None + else torch.ones((1, 10), dtype=torch.long, device=device) ) example_decoder_input_ids = ( - decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) + decoder_input_ids + if decoder_input_ids is not None + else torch.tensor([[0]], dtype=torch.long, device=device) ) # Start token - example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) + example_cache_position = ( + cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device) + ) example_encoder_hidden_states = ( encoder_hidden_states if encoder_hidden_states is not None else torch.zeros( - (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32 + (self.generation_config.cache_config.batch_size, 10, self.config.d_model), + dtype=torch.float32, + device=device, ) ) self.exported_encoder = self._export_encoder(example_encoder_input_ids) diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index a621881098..2fe532f673 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -248,7 +248,7 @@ class Cohere2IntegrationTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="", padding_side="right") # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index e029b22075..d7f7a0ce0e 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 0f06ed3cea..76418997da 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -335,7 +335,7 @@ class Gemma2IntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index fcd060a37b..2ffc423be4 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -322,7 +322,7 @@ class LlamaIntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 4e94d23101..eea85c7536 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -347,7 +347,7 @@ class OlmoIntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py index b5147a83cc..29fb3517d6 100644 --- a/tests/models/olmo2/test_modeling_olmo2.py +++ b/tests/models/olmo2/test_modeling_olmo2.py @@ -348,7 +348,7 @@ class Olmo2IntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 098febd0ee..aec3c30802 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -384,7 +384,7 @@ class Phi3IntegrationTest(unittest.TestCase): config.rope_scaling["type"] = "default" # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index fb15c21345..d66341901e 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -270,7 +270,7 @@ class Qwen2IntegrationTest(unittest.TestCase): ].shape[-1] # Load model - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 5f961ac79e..424be1c866 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -261,7 +261,7 @@ class Qwen3IntegrationTest(unittest.TestCase): max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ "input_ids" ].shape[-1] - device = "cpu" + device = torch_device dtype = torch.bfloat16 cache_implementation = "static" attn_implementation = "sdpa" diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 2b8f0d9a9e..97b8cc2511 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -1774,7 +1774,7 @@ class T5ModelIntegrationTests(unittest.TestCase): from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration from transformers.integrations.executorch import Seq2SeqLMExportableModule - device = "cpu" + device = torch_device batch_size = 1 max_cache_length = 1234 max_hidden_seq_length = 5678 diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index ed8f7d4da1..26f9f56996 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase): self.skipTest(reason="This test requires torch >= 2.3 to run.") set_seed(0) - device = "cpu" + device = torch_device dtype = "bfloat16" cache_implementation = "static" attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention @@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase): self.assertEqual(n_static_value_caches, model.config.num_hidden_layers) # Export with dynamic shapes - input_ids = torch.zeros((1, 3), dtype=torch.long) - cache_position = torch.tensor([0, 1, 2], dtype=torch.long) + input_ids = torch.zeros((1, 3), dtype=torch.long, device=device) + cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device) dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}} strict = version.parse(torch.__version__) != version.parse("2.7.0") exported_program = convert_and_export_with_cache(