Fix convert_and_export_with_cache failures for GPU models (#38976)
* Add the `device` option for `generate()` * Add device for default tensors to avoid tensor mismatch * [test] Enable test_static_cache_exportability for torch_device * infer device from the prompt_token_ids * Add device for generated tensor * [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU * fix format * infer device from the model
This commit is contained in:
@@ -248,7 +248,7 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -335,7 +335,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -322,7 +322,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -347,7 +347,7 @@ class OlmoIntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -348,7 +348,7 @@ class Olmo2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -384,7 +384,7 @@ class Phi3IntegrationTest(unittest.TestCase):
|
||||
config.rope_scaling["type"] = "default"
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -270,7 +270,7 @@ class Qwen2IntegrationTest(unittest.TestCase):
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -261,7 +261,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||
"input_ids"
|
||||
].shape[-1]
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
@@ -1774,7 +1774,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration
|
||||
from transformers.integrations.executorch import Seq2SeqLMExportableModule
|
||||
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
batch_size = 1
|
||||
max_cache_length = 1234
|
||||
max_hidden_seq_length = 5678
|
||||
|
||||
@@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
set_seed(0)
|
||||
device = "cpu"
|
||||
device = torch_device
|
||||
dtype = "bfloat16"
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||
@@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
# Export with dynamic shapes
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
input_ids = torch.zeros((1, 3), dtype=torch.long, device=device)
|
||||
cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
|
||||
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
|
||||
strict = version.parse(torch.__version__) != version.parse("2.7.0")
|
||||
exported_program = convert_and_export_with_cache(
|
||||
|
||||
Reference in New Issue
Block a user