Fix convert_and_export_with_cache failures for GPU models (#38976)
* Add the `device` option for `generate()` * Add device for default tensors to avoid tensor mismatch * [test] Enable test_static_cache_exportability for torch_device * infer device from the prompt_token_ids * Add device for generated tensor * [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU * fix format * infer device from the model
This commit is contained in:
@@ -107,9 +107,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
strict(`Optional[bool]`):
|
||||
Flag to instruct `torch.export` to use `torchdynamo`.
|
||||
"""
|
||||
if hasattr(self.model, "base_model_prefix"):
|
||||
base = getattr(self.model, self.model.base_model_prefix, self.model)
|
||||
model_device = base.device
|
||||
elif hasattr(self.model, "model"):
|
||||
model_device = self.model.model.device
|
||||
else:
|
||||
model_device = "cpu"
|
||||
logging.warning(
|
||||
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
|
||||
)
|
||||
|
||||
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
||||
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||
example_input_ids = (
|
||||
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device)
|
||||
)
|
||||
example_cache_position = (
|
||||
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
|
||||
)
|
||||
|
||||
exported_program = torch.export.export(
|
||||
self.model,
|
||||
@@ -322,7 +336,9 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
|
||||
exported_program: torch.export.ExportedProgram,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate a sequence of tokens using an exported program.
|
||||
@@ -341,6 +357,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
|
||||
"""
|
||||
device = prompt_token_ids.device
|
||||
prompt_token_len = prompt_token_ids.shape[-1]
|
||||
max_generation_length = prompt_token_len + max_new_tokens
|
||||
for buffer_name, buffer in exported_program.named_buffers():
|
||||
@@ -353,7 +370,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
for input_pos in range(min(max_generation_length, prompt_token_len)):
|
||||
result = exported_program.module().forward(
|
||||
input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
|
||||
cache_position=torch.tensor([input_pos], dtype=torch.long),
|
||||
cache_position=torch.tensor([input_pos], dtype=torch.long, device=device),
|
||||
)
|
||||
response_tokens.append(prompt_token_ids[0][input_pos].item())
|
||||
|
||||
@@ -362,13 +379,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
|
||||
while len(response_tokens) < max_generation_length:
|
||||
result = exported_program.module().forward(
|
||||
input_ids=torch.tensor([[current_token]], dtype=torch.long),
|
||||
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
|
||||
input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device),
|
||||
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device),
|
||||
)
|
||||
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
|
||||
response_tokens.append(current_token)
|
||||
|
||||
return torch.tensor([response_tokens], dtype=torch.long)
|
||||
return torch.tensor([response_tokens], dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
||||
@@ -484,10 +501,14 @@ def convert_and_export_with_cache(
|
||||
with torch.no_grad():
|
||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||
example_input_ids = (
|
||||
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
||||
example_input_ids
|
||||
if example_input_ids is not None
|
||||
else torch.tensor([[1]], dtype=torch.long, device=model.device)
|
||||
)
|
||||
example_cache_position = (
|
||||
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||
example_cache_position
|
||||
if example_cache_position is not None
|
||||
else torch.tensor([0], dtype=torch.long, device=model.device)
|
||||
)
|
||||
|
||||
if is_torch_greater_or_equal("2.6.0"):
|
||||
@@ -602,7 +623,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||
self.exported_decoder = None
|
||||
|
||||
def _export_encoder(self, encoder_input_ids):
|
||||
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
|
||||
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval()
|
||||
|
||||
# Define dynamic sequence length for encoder
|
||||
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
|
||||
@@ -645,18 +666,27 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||
return exported_decoder
|
||||
|
||||
def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
|
||||
device = self.full_model.device
|
||||
example_encoder_input_ids = (
|
||||
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
|
||||
encoder_input_ids
|
||||
if encoder_input_ids is not None
|
||||
else torch.ones((1, 10), dtype=torch.long, device=device)
|
||||
)
|
||||
example_decoder_input_ids = (
|
||||
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
|
||||
decoder_input_ids
|
||||
if decoder_input_ids is not None
|
||||
else torch.tensor([[0]], dtype=torch.long, device=device)
|
||||
) # Start token
|
||||
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||
example_cache_position = (
|
||||
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device)
|
||||
)
|
||||
example_encoder_hidden_states = (
|
||||
encoder_hidden_states
|
||||
if encoder_hidden_states is not None
|
||||
else torch.zeros(
|
||||
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
|
||||
(self.generation_config.cache_config.batch_size, 10, self.config.d_model),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
|
||||
|
||||
@@ -248,7 +248,7 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -335,7 +335,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -322,7 +322,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -347,7 +347,7 @@ class OlmoIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -348,7 +348,7 @@ class Olmo2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -384,7 +384,7 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
config.rope_scaling["type"] = "default"
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -270,7 +270,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -261,7 +261,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||
"input_ids"
|
||||
].shape[-1]
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -1774,7 +1774,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
|
||||
from transformers.integrations.executorch import Seq2SeqLMExportableModule
|
||||
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
batch_size = 1
|
||||
max_cache_length = 1234
|
||||
max_hidden_seq_length = 5678
|
||||
|
||||
@@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
set_seed(0)
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = "bfloat16"
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||
@@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
# Export with dynamic shapes
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long, device=device)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
||||
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
||||
exported_program = convert_and_export_with_cache(
|
||||
|
||||
Reference in New Issue
Block a user