Fix convert_and_export_with_cache failures for GPU models (#38976)

* Add the `device` option for `generate()`

* Add device for default tensors to avoid tensor mismatch

* [test] Enable test_static_cache_exportability for torch_device

* infer device from the prompt_token_ids

* Add device for generated tensor

* [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU

* fix format

* infer device from the model
This commit is contained in:
Stonepia
2025-07-17 21:12:32 +08:00
committed by GitHub
parent 54680d75c9
commit fc700c2a26
12 changed files with 57 additions and 27 deletions

View File

@@ -107,9 +107,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
strict(`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
if hasattr(self.model, "base_model_prefix"):
base = getattr(self.model, self.model.base_model_prefix, self.model)
model_device = base.device
elif hasattr(self.model, "model"):
model_device = self.model.model.device
else:
model_device = "cpu"
logging.warning(
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
)
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
example_input_ids = (
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device)
)
example_cache_position = (
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
)
exported_program = torch.export.export(
self.model,
@@ -322,7 +336,9 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
@staticmethod
def generate(
exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
exported_program: torch.export.ExportedProgram,
prompt_token_ids: torch.Tensor,
max_new_tokens: int,
) -> torch.Tensor:
"""
Generate a sequence of tokens using an exported program.
@@ -341,6 +357,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
Returns:
torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
"""
device = prompt_token_ids.device
prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers():
@@ -353,7 +370,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
for input_pos in range(min(max_generation_length, prompt_token_len)):
result = exported_program.module().forward(
input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
cache_position=torch.tensor([input_pos], dtype=torch.long),
cache_position=torch.tensor([input_pos], dtype=torch.long, device=device),
)
response_tokens.append(prompt_token_ids[0][input_pos].item())
@@ -362,13 +379,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
while len(response_tokens) < max_generation_length:
result = exported_program.module().forward(
input_ids=torch.tensor([[current_token]], dtype=torch.long),
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device),
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device),
)
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
response_tokens.append(current_token)
return torch.tensor([response_tokens], dtype=torch.long)
return torch.tensor([response_tokens], dtype=torch.long, device=device)
class TorchExportableModuleWithHybridCache(torch.nn.Module):
@@ -484,10 +501,14 @@ def convert_and_export_with_cache(
with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_input_ids
if example_input_ids is not None
else torch.tensor([[1]], dtype=torch.long, device=model.device)
)
example_cache_position = (
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
example_cache_position
if example_cache_position is not None
else torch.tensor([0], dtype=torch.long, device=model.device)
)
if is_torch_greater_or_equal("2.6.0"):
@@ -602,7 +623,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
self.exported_decoder = None
def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval()
# Define dynamic sequence length for encoder
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
@@ -645,18 +666,27 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
return exported_decoder
def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
device = self.full_model.device
example_encoder_input_ids = (
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
encoder_input_ids
if encoder_input_ids is not None
else torch.ones((1, 10), dtype=torch.long, device=device)
)
example_decoder_input_ids = (
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
decoder_input_ids
if decoder_input_ids is not None
else torch.tensor([[0]], dtype=torch.long, device=device)
) # Start token
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
example_cache_position = (
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device)
)
example_encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else torch.zeros(
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
(self.generation_config.cache_config.batch_size, 10, self.config.d_model),
dtype=torch.float32,
device=device,
)
)
self.exported_encoder = self._export_encoder(example_encoder_input_ids)

View File

@@ -248,7 +248,7 @@ class Cohere2IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -335,7 +335,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -322,7 +322,7 @@ class LlamaIntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -347,7 +347,7 @@ class OlmoIntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -348,7 +348,7 @@ class Olmo2IntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -384,7 +384,7 @@ class Phi3IntegrationTest(unittest.TestCase):
config.rope_scaling["type"] = "default"
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -270,7 +270,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -261,7 +261,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"

View File

@@ -1774,7 +1774,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
from transformers.integrations.executorch import Seq2SeqLMExportableModule
device = "cpu"
device = torch_device
batch_size = 1
max_cache_length = 1234
max_hidden_seq_length = 5678

View File

@@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
set_seed(0)
device = "cpu"
device = torch_device
dtype = "bfloat16"
cache_implementation = "static"
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
@@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes
input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
input_ids = torch.zeros((1, 3), dtype=torch.long, device=device)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = convert_and_export_with_cache(