Fix convert_and_export_with_cache failures for GPU models (#38976)
* Add the `device` option for `generate()` * Add device for default tensors to avoid tensor mismatch * [test] Enable test_static_cache_exportability for torch_device * infer device from the prompt_token_ids * Add device for generated tensor * [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU * fix format * infer device from the model
This commit is contained in:
@@ -107,9 +107,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
|||||||
strict(`Optional[bool]`):
|
strict(`Optional[bool]`):
|
||||||
Flag to instruct `torch.export` to use `torchdynamo`.
|
Flag to instruct `torch.export` to use `torchdynamo`.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self.model, "base_model_prefix"):
|
||||||
|
base = getattr(self.model, self.model.base_model_prefix, self.model)
|
||||||
|
model_device = base.device
|
||||||
|
elif hasattr(self.model, "model"):
|
||||||
|
model_device = self.model.model.device
|
||||||
|
else:
|
||||||
|
model_device = "cpu"
|
||||||
|
logging.warning(
|
||||||
|
"TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
|
||||||
|
)
|
||||||
|
|
||||||
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
example_input_ids = (
|
||||||
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device)
|
||||||
|
)
|
||||||
|
example_cache_position = (
|
||||||
|
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
|
||||||
|
)
|
||||||
|
|
||||||
exported_program = torch.export.export(
|
exported_program = torch.export.export(
|
||||||
self.model,
|
self.model,
|
||||||
@@ -322,7 +336,9 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(
|
def generate(
|
||||||
exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
|
exported_program: torch.export.ExportedProgram,
|
||||||
|
prompt_token_ids: torch.Tensor,
|
||||||
|
max_new_tokens: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Generate a sequence of tokens using an exported program.
|
Generate a sequence of tokens using an exported program.
|
||||||
@@ -341,6 +357,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
|
torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
|
||||||
"""
|
"""
|
||||||
|
device = prompt_token_ids.device
|
||||||
prompt_token_len = prompt_token_ids.shape[-1]
|
prompt_token_len = prompt_token_ids.shape[-1]
|
||||||
max_generation_length = prompt_token_len + max_new_tokens
|
max_generation_length = prompt_token_len + max_new_tokens
|
||||||
for buffer_name, buffer in exported_program.named_buffers():
|
for buffer_name, buffer in exported_program.named_buffers():
|
||||||
@@ -353,7 +370,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
for input_pos in range(min(max_generation_length, prompt_token_len)):
|
for input_pos in range(min(max_generation_length, prompt_token_len)):
|
||||||
result = exported_program.module().forward(
|
result = exported_program.module().forward(
|
||||||
input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
|
input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
|
||||||
cache_position=torch.tensor([input_pos], dtype=torch.long),
|
cache_position=torch.tensor([input_pos], dtype=torch.long, device=device),
|
||||||
)
|
)
|
||||||
response_tokens.append(prompt_token_ids[0][input_pos].item())
|
response_tokens.append(prompt_token_ids[0][input_pos].item())
|
||||||
|
|
||||||
@@ -362,13 +379,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
|
|
||||||
while len(response_tokens) < max_generation_length:
|
while len(response_tokens) < max_generation_length:
|
||||||
result = exported_program.module().forward(
|
result = exported_program.module().forward(
|
||||||
input_ids=torch.tensor([[current_token]], dtype=torch.long),
|
input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device),
|
||||||
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
|
cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device),
|
||||||
)
|
)
|
||||||
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
|
current_token = torch.argmax(result[:, -1, :], dim=-1).item()
|
||||||
response_tokens.append(current_token)
|
response_tokens.append(current_token)
|
||||||
|
|
||||||
return torch.tensor([response_tokens], dtype=torch.long)
|
return torch.tensor([response_tokens], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
||||||
@@ -484,10 +501,14 @@ def convert_and_export_with_cache(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||||
example_input_ids = (
|
example_input_ids = (
|
||||||
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
example_input_ids
|
||||||
|
if example_input_ids is not None
|
||||||
|
else torch.tensor([[1]], dtype=torch.long, device=model.device)
|
||||||
)
|
)
|
||||||
example_cache_position = (
|
example_cache_position = (
|
||||||
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
|
example_cache_position
|
||||||
|
if example_cache_position is not None
|
||||||
|
else torch.tensor([0], dtype=torch.long, device=model.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_greater_or_equal("2.6.0"):
|
if is_torch_greater_or_equal("2.6.0"):
|
||||||
@@ -602,7 +623,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
|||||||
self.exported_decoder = None
|
self.exported_decoder = None
|
||||||
|
|
||||||
def _export_encoder(self, encoder_input_ids):
|
def _export_encoder(self, encoder_input_ids):
|
||||||
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
|
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval()
|
||||||
|
|
||||||
# Define dynamic sequence length for encoder
|
# Define dynamic sequence length for encoder
|
||||||
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
|
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
|
||||||
@@ -645,18 +666,27 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
|||||||
return exported_decoder
|
return exported_decoder
|
||||||
|
|
||||||
def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
|
def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
|
||||||
|
device = self.full_model.device
|
||||||
example_encoder_input_ids = (
|
example_encoder_input_ids = (
|
||||||
encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
|
encoder_input_ids
|
||||||
|
if encoder_input_ids is not None
|
||||||
|
else torch.ones((1, 10), dtype=torch.long, device=device)
|
||||||
)
|
)
|
||||||
example_decoder_input_ids = (
|
example_decoder_input_ids = (
|
||||||
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
|
decoder_input_ids
|
||||||
|
if decoder_input_ids is not None
|
||||||
|
else torch.tensor([[0]], dtype=torch.long, device=device)
|
||||||
) # Start token
|
) # Start token
|
||||||
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
example_cache_position = (
|
||||||
|
cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device)
|
||||||
|
)
|
||||||
example_encoder_hidden_states = (
|
example_encoder_hidden_states = (
|
||||||
encoder_hidden_states
|
encoder_hidden_states
|
||||||
if encoder_hidden_states is not None
|
if encoder_hidden_states is not None
|
||||||
else torch.zeros(
|
else torch.zeros(
|
||||||
(self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
|
(self.generation_config.cache_config.batch_size, 10, self.config.d_model),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
|
self.exported_encoder = self._export_encoder(example_encoder_input_ids)
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
|
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -347,7 +347,7 @@ class OlmoIntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ class Olmo2IntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -384,7 +384,7 @@ class Phi3IntegrationTest(unittest.TestCase):
|
|||||||
config.rope_scaling["type"] = "default"
|
config.rope_scaling["type"] = "default"
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -270,7 +270,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
|||||||
].shape[-1]
|
].shape[-1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
].shape[-1]
|
].shape[-1]
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
|
|||||||
@@ -1774,7 +1774,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
|
||||||
from transformers.integrations.executorch import Seq2SeqLMExportableModule
|
from transformers.integrations.executorch import Seq2SeqLMExportableModule
|
||||||
|
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
max_cache_length = 1234
|
max_cache_length = 1234
|
||||||
max_hidden_seq_length = 5678
|
max_hidden_seq_length = 5678
|
||||||
|
|||||||
@@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
|||||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
device = "cpu"
|
device = torch_device
|
||||||
dtype = "bfloat16"
|
dtype = "bfloat16"
|
||||||
cache_implementation = "static"
|
cache_implementation = "static"
|
||||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||||
@@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||||
|
|
||||||
# Export with dynamic shapes
|
# Export with dynamic shapes
|
||||||
input_ids = torch.zeros((1, 3), dtype=torch.long)
|
input_ids = torch.zeros((1, 3), dtype=torch.long, device=device)
|
||||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
|
cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
|
||||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
||||||
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
||||||
exported_program = convert_and_export_with_cache(
|
exported_program = convert_and_export_with_cache(
|
||||||
|
|||||||
Reference in New Issue
Block a user